Spaces: Running on Zero
import torch
import torch.nn as nn
from sync_models.modules import *


class Transformer_RGB(nn.Module):
    """Two-branch audio-visual embedding model.

    Video branch: 3D-CNN front-end -> positional encoding -> transformer
    encoder -> MLP projection to a 1024-d embedding.
    Audio branch: 2D-CNN front-end -> FC head projection to 1024-d.
    (An LSTM and small linear layers are also registered; they are not used
    by the forward passes visible in this file — presumably used by callers
    or other methods. TODO confirm.)
    """

    def __init__(self):
        super().__init__()

        # --- video branch ---
        self.net_vid = self.build_net_vid()
        self.ff_vid = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
        )
        self.pos_encoder = PositionalEncoding_RGB(d_model=512)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=512, nhead=8, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

        # --- audio branch ---
        self.net_aud = self.build_net_aud()
        self.lstm = nn.LSTM(
            512, 256, num_layers=1, bidirectional=True, batch_first=True
        )
        self.ff_aud = NetFC_2D(input_dim=512, hidden_dim=512, embed_dim=1024)

        # Learnable 1x1 scale for similarity logits, initialised to 1.
        self.logits_scale = nn.Linear(1, 1, bias=False)
        torch.nn.init.ones_(self.logits_scale.weight)
        self.fc = nn.Linear(1, 1)
| def build_net_vid(self): | |
| layers = [ | |
| { | |
| 'type': 'conv3d', | |
| 'n_channels': 64, | |
| 'kernel_size': (5, 7, 7), | |
| 'stride': (1, 3, 3), | |
| 'padding': (0), | |
| 'maxpool': { | |
| 'kernel_size': (1, 3, 3), | |
| 'stride': (1, 2, 2) | |
| } | |
| }, | |
| { | |
| 'type': 'conv3d', | |
| 'n_channels': 128, | |
| 'kernel_size': (1, 5, 5), | |
| 'stride': (1, 2, 2), | |
| 'padding': (0, 0, 0), | |
| }, | |
| { | |
| 'type': 'conv3d', | |
| 'n_channels': 256, | |
| 'kernel_size': (1, 3, 3), | |
| 'stride': (1, 2, 2), | |
| 'padding': (0, 1, 1), | |
| }, | |
| { | |
| 'type': 'conv3d', | |
| 'n_channels': 256, | |
| 'kernel_size': (1, 3, 3), | |
| 'stride': (1, 1, 2), | |
| 'padding': (0, 1, 1), | |
| }, | |
| { | |
| 'type': 'conv3d', | |
| 'n_channels': 256, | |
| 'kernel_size': (1, 3, 3), | |
| 'stride': (1, 1, 1), | |
| 'padding': (0, 1, 1), | |
| 'maxpool': { | |
| 'kernel_size': (1, 3, 3), | |
| 'stride': (1, 2, 2) | |
| } | |
| }, | |
| { | |
| 'type': 'fc3d', | |
| 'n_channels': 512, | |
| 'kernel_size': (1, 4, 4), | |
| 'stride': (1, 1, 1), | |
| 'padding': (0), | |
| }, | |
| ] | |
| return VGGNet(n_channels_in=3, layers=layers) | |
| def build_net_aud(self): | |
| layers = [ | |
| { | |
| 'type': 'conv2d', | |
| 'n_channels': 64, | |
| 'kernel_size': (3, 3), | |
| 'stride': (2, 2), | |
| 'padding': (1, 1), | |
| 'maxpool': { | |
| 'kernel_size': (3, 3), | |
| 'stride': (2, 2) | |
| } | |
| }, | |
| { | |
| 'type': 'conv2d', | |
| 'n_channels': 192, | |
| 'kernel_size': (3, 3), | |
| 'stride': (1, 2), | |
| 'padding': (1, 1), | |
| 'maxpool': { | |
| 'kernel_size': (3, 3), | |
| 'stride': (2, 2) | |
| } | |
| }, | |
| { | |
| 'type': 'conv2d', | |
| 'n_channels': 384, | |
| 'kernel_size': (3, 3), | |
| 'stride': (1, 1), | |
| 'padding': (1, 1), | |
| }, | |
| { | |
| 'type': 'conv2d', | |
| 'n_channels': 256, | |
| 'kernel_size': (3, 3), | |
| 'stride': (1, 1), | |
| 'padding': (1, 1), | |
| }, | |
| { | |
| 'type': 'conv2d', | |
| 'n_channels': 256, | |
| 'kernel_size': (3, 3), | |
| 'stride': (1, 1), | |
| 'padding': (1, 1), | |
| 'maxpool': { | |
| 'kernel_size': (2, 3), | |
| 'stride': (2, 2) | |
| } | |
| }, | |
| { | |
| 'type': 'fc2d', | |
| 'n_channels': 512, | |
| 'kernel_size': (4, 2), | |
| 'stride': (1, 1), | |
| 'padding': (0, 0), | |
| }, | |
| ] | |
| return VGGNet(n_channels_in=1, layers=layers) | |
| def forward_vid(self, x, return_feats=False): | |
| out_conv = self.net_vid(x).squeeze(-1).squeeze(-1) | |
| # print("Conv: ", out_conv.shape) # Bx1024x21x1x1 | |
| out = self.pos_encoder(out_conv.transpose(1,2)) | |
| out_trans = self.transformer_encoder(out) | |
| # print("Transformer: ", out_trans.shape) # Bx21x1024 | |
| out = self.ff_vid(out_trans).transpose(1,2) | |
| # print("MLP output: ", out.shape) # Bx1024 | |
| if return_feats: | |
| return out, out_conv | |
| else: | |
| return out | |
| def forward_aud(self, x): | |
| out = self.net_aud(x) | |
| out = self.ff_aud(out) | |
| out = out.squeeze(-1) | |
| return out | |