Forbu14 committed on
Commit
2ddf390
·
verified ·
1 Parent(s): f38244c

Upload dc_3dunet_film.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dc_3dunet_film.py +269 -0
dc_3dunet_film.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import List
5
+ import math
6
+
7
+ # ==============================================================================
8
+ # == Conditioning Blocks
9
+ # ==============================================================================
10
+
11
+
12
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalar values.

    For an input of shape ``(batch,)`` the output has shape ``(batch, dim)``
    (``(batch, 1)`` when ``dim`` is 0 or 1). The first ``dim // 2`` columns
    are sines and the next ``dim // 2`` columns are cosines of the input
    multiplied by frequencies spaced geometrically from 1 down to 1/10000.

    When ``dim`` is odd and at least 3, a trailing zero column is appended
    so that the output width really is ``dim``; without it the result would
    be one column short of what ``dim``-sized consumers expect.

    Args:
        dim: Requested width of the embedding.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (indexed as ``x[:, None]``, so at least 1-D)."""
        half_dim = self.dim // 2
        if half_dim == 0:
            # dim <= 1: there is no room for a sin/cos pair, use sin only.
            return torch.sin(x).unsqueeze(-1)

        if half_dim == 1:
            # A single sin/cos pair at unit frequency. This case is separate
            # because the general formula divides by (half_dim - 1).
            angles = x[:, None] * 1.0
        else:
            log_step = math.log(10000) / (half_dim - 1)
            freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
            angles = x[:, None] * freqs[None, :]

        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Pad odd widths with one zero column so the last axis equals dim.
            emb = torch.cat((emb, torch.zeros_like(emb[..., :1])), dim=-1)
        return emb
33
+
34
+
35
class FilmLayer(nn.Module):
    """Feature-wise linear modulation (FiLM) of a 5-D feature map.

    A linear layer followed by ReLU maps the context vector to
    ``2 * num_channels`` values. The first half is used as a per-channel
    scale offset and the second half as a per-channel shift; because of the
    ReLU both are non-negative.

    Args:
        embedding_dim: Width of the context vector.
        num_channels: Number of channels of the feature map to modulate.
    """

    def __init__(self, embedding_dim, num_channels):
        super().__init__()
        projection = nn.Linear(embedding_dim, 2 * num_channels)
        self.mlp = nn.Sequential(projection, nn.ReLU())

    def forward(self, x, context):
        """Return ``(1 + scale) * x + shift`` with per-sample, per-channel terms."""
        batch, channels = x.shape[0], x.shape[1]
        params = self.mlp(context)

        # Columns [0, channels) hold the scale, the remaining columns the shift.
        # Each is reshaped to broadcast over the three trailing axes of x.
        gamma = params[:, :channels].view(batch, channels, 1, 1, 1)
        beta = params[:, channels:].view(batch, channels, 1, 1, 1)

        return (1.0 + gamma) * x + beta
49
+
50
+
51
+ # ==============================================================================
52
+ # == 3D U-Net Components
53
+ # ==============================================================================
54
+
55
+
56
class ResNetBlock3D(nn.Module):
    """
    Residual block made of two 3x3x3 convolutions with FiLM applied in between.

    Only the slices along axis 2 at index ``context_frames`` and beyond are
    modulated by the FiLM layer; the first ``context_frames`` slices pass
    through unmodified. The residual path is an identity when the channel
    counts match, otherwise a 1x1x1 convolution followed by instance norm.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        embedding_dim: Width of the context vector given to the FiLM layer.
        context_frames: Number of leading slices along axis 2 left unmodulated.
    """

    def __init__(
        self, in_channels: int, out_channels: int, embedding_dim: int, context_frames: int
    ):
        super().__init__()
        self.context_frames = context_frames

        # Both main-path convolutions preserve the (D, H, W) extent.
        conv_kwargs = dict(kernel_size=(3, 3, 3), padding=(1, 1, 1), bias=False)

        self.conv1 = nn.Conv3d(in_channels, out_channels, **conv_kwargs)
        # bn1 is a pass-through; the source had an InstanceNorm3d commented
        # out in its place.
        self.bn1 = nn.Identity()
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.InstanceNorm3d(out_channels, affine=True)

        self.film = FilmLayer(embedding_dim, out_channels)

        if in_channels == out_channels:
            # Empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.InstanceNorm3d(out_channels, affine=True),
            )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Run the block on a 5-D tensor ``x`` conditioned on ``context``."""
        cut = self.context_frames
        hidden = self.relu(self.bn1(self.conv1(x)))

        # Split along axis 2: the leading part is kept as is, the rest is
        # modulated by FiLM, then both are stitched back together.
        untouched = hidden[:, :, :cut, :, :]
        modulated = self.film(hidden[:, :, cut:, :, :], context)
        hidden = torch.cat([untouched, modulated], dim=2)

        hidden = self.bn2(self.conv2(hidden))
        return self.relu(hidden + self.shortcut(x))
107
+
108
+
109
+ # ==============================================================================
110
+ # == Full 3D U-Net Architecture
111
+ # ==============================================================================
112
class UNet_DCAE_3D(nn.Module):
    """
    A 3D U-Net architecture that only performs spatial down/up-sampling, with FiLM conditioning.

    Pooling and upsampling use a factor of ``(1, 2, 2)``, so axis 2 of the
    input is never resampled. Every ResNet block receives the same context
    vector, produced by ``time_mlp`` from the conditioning tensor ``t``.

    Args:
        in_channels: Channels of the input volume.
        out_channels: Channels of the output volume.
        features: Base widths per level. Encoder level ``i`` outputs
            ``2 * features[i]`` channels and the matching decoder level
            outputs ``features[i]`` channels. The list is copied, and it does
            not need to double from level to level.
        context_dim: Number of columns in ``t``.
        embedding_dim: Width of the context vector given to the FiLM layers.
        context_frames: Forwarded to every ``ResNetBlock3D``.
        num_additional_resnet_blocks: Extra ResNet blocks appended after each
            decoder level.
        time_emb_dim: Width of the sinusoidal embedding of the last column of ``t``.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        features: List[int] = [32, 64, 128, 256],
        context_dim: int = 4,
        embedding_dim: int = 128,
        context_frames: int = 4,
        num_additional_resnet_blocks: int = 0,
        time_emb_dim: int = 64,
    ):
        super().__init__()
        # Copy so that neither the shared default list nor a caller-owned
        # list is aliased by the model.
        self.features = list(features)
        self.context_dim = context_dim
        self.embedding_dim = embedding_dim
        self.context_frames = context_frames
        self.num_additional_resnet_blocks = num_additional_resnet_blocks
        self.time_emb_dim = time_emb_dim

        # --- Time Embedding ---
        # All columns of t except the last are fed in raw; the last column is
        # replaced by its sinusoidal embedding.
        time_mlp_input_dim = context_dim - 1 + self.time_emb_dim
        self.time_mlp = nn.Sequential(
            nn.Linear(time_mlp_input_dim, 128), nn.ReLU(), nn.Linear(128, embedding_dim)
        )
        self.time_emb = SinusoidalPosEmb(dim=self.time_emb_dim)

        self.encoder_convs = nn.ModuleList()
        self.decoder_convs = nn.ModuleList()
        self.downs = nn.ModuleList()

        # --- Encoder (Downsampling Path) ---
        current_channels = in_channels
        for feature in self.features:
            self.encoder_convs.append(
                ResNetBlock3D(
                    current_channels, feature * 2, embedding_dim, self.context_frames
                )
            )
            self.downs.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
            current_channels = feature * 2

        # --- Bottleneck ---
        bottleneck_channels = self.features[-1] * 2
        self.bottleneck = ResNetBlock3D(
            bottleneck_channels, bottleneck_channels, embedding_dim, self.context_frames
        )

        # --- Decoder (Upsampling Path) ---
        # Each decoder block consumes the skip tensor (2 * feature channels)
        # concatenated with the tensor coming from the level below. Deriving
        # the width from both parts, rather than assuming feature * 4, keeps
        # the model valid for feature lists that do not double per level.
        incoming_channels = bottleneck_channels
        for feature in reversed(self.features):
            skip_channels = feature * 2
            self.decoder_convs.append(
                ResNetBlock3D(
                    skip_channels + incoming_channels,
                    feature,
                    embedding_dim,
                    self.context_frames,
                )
            )
            incoming_channels = feature

        self.additional_resnet_blocks = nn.ModuleList()
        for feature in reversed(self.features):
            blocks = nn.ModuleList()
            for _ in range(self.num_additional_resnet_blocks):
                blocks.append(
                    ResNetBlock3D(feature, feature, embedding_dim, self.context_frames)
                )
            self.additional_resnet_blocks.append(blocks)

        # --- Final Output Layer ---
        self.final_conv = nn.Conv3d(
            self.features[0], out_channels, kernel_size=(1, 1, 1)
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the U-Net.

        Args:
            x: 5-D input volume whose axis 1 has ``in_channels`` entries.
            t: 2-D conditioning tensor. Its last column is sinusoidally
                embedded and the remaining columns are used unchanged.

        Returns:
            The output of the final 1x1x1 convolution.
        """
        time_val = t[:, -1]
        emb = self.time_emb(time_val)
        raw_context = t[:, :-1]
        context = self.time_mlp(torch.cat([raw_context, emb], dim=1))

        # --- Encoder Path ---
        skip_connections = []
        for encoder, down in zip(self.encoder_convs, self.downs):
            x = encoder(x, context)
            skip_connections.append(x)
            x = down(x)

        # --- Bottleneck ---
        x = self.bottleneck(x, context)

        # --- Decoder Path ---
        levels = zip(
            self.decoder_convs,
            self.additional_resnet_blocks,
            reversed(skip_connections),
        )
        for decoder, extra_blocks, skip_connection in levels:
            x = F.interpolate(x, scale_factor=(1, 2, 2), mode='nearest')

            # Pooling floors odd sizes, so the upsampled tensor can be smaller
            # than the skip tensor. Only the (D, H, W) part is compared,
            # because the channel counts are allowed to differ.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            x = decoder(torch.cat((skip_connection, x), dim=1), context)

            for block in extra_blocks:
                x = block(x, context)

        return self.final_conv(x)
219
+
220
+
221
+ # --- Example Usage ---
222
if __name__ == "__main__":
    # Smoke test: build a small model, run one forward pass on random data
    # and check that the output shape matches the input volume shape.
    print(
        "--- Testing Full 3D U-Net with DC-AE, ResNet Blocks, and FiLM conditioning ---"
    )

    # Model / data configuration
    BATCH_SIZE = 2
    IN_CHANNELS = 3
    OUT_CHANNELS = 3
    CONTEXT_FRAMES = 4
    CONTEXT_DIM = 128
    IMG_DEPTH = CONTEXT_FRAMES + 2
    IMG_HEIGHT = 128
    IMG_WIDTH = 128

    # Random input volume laid out as (N, C, D, H, W) plus a conditioning tensor
    volume_shape = (BATCH_SIZE, IN_CHANNELS, IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH)
    input_tensor = torch.randn(*volume_shape)
    t = torch.rand(BATCH_SIZE, CONTEXT_DIM)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Time shape: {t.shape}")

    # Build the model
    model_kwargs = dict(
        in_channels=IN_CHANNELS,
        out_channels=OUT_CHANNELS,
        features=[64, 128, 256],
        context_dim=CONTEXT_DIM,
        embedding_dim=128,
        context_frames=CONTEXT_FRAMES,
        num_additional_resnet_blocks=3,
    )
    model = UNet_DCAE_3D(**model_kwargs)

    # Forward pass
    output_tensor = model(input_tensor, t)
    print(f"Output shape: {output_tensor.shape}")

    # The network must return a volume of the same extent with OUT_CHANNELS channels
    expected_shape = (BATCH_SIZE, OUT_CHANNELS, IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH)
    assert output_tensor.shape == expected_shape, (
        f"Shape mismatch! Expected {expected_shape}, got {output_tensor.shape}"
    )

    print("✅ 3D U-Net model shape test PASSED.")

    trainable = (p for p in model.parameters() if p.requires_grad)
    num_params = sum(p.numel() for p in trainable)
    print(f"Total trainable parameters: {num_params:,}")