MohmedAnik committed
Commit 8f4a471 · verified · 1 Parent(s): 21d3dc9

Update vision_tower.py

Files changed (1)
  1. vision_tower.py +6 -12
vision_tower.py CHANGED
@@ -53,9 +53,7 @@ class MLP_dim(nn.Module):
         return self.net2(self.net1(x))
 
 class FLIP_Dinov2Embeddings(Dinov2Embeddings):
-    """
-    Construct the CLS token, mask token, position and patch embeddings.
-    """
+
 
     def __init__(self, config: Dinov2Config) -> None:
         super().__init__(config)
@@ -65,17 +63,15 @@ class FLIP_Dinov2Embeddings(Dinov2Embeddings):
         target_dtype = self.patch_embeddings.projection.weight.dtype
         embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
 
-        # add the [CLS] token to the embedded patch tokens
+
         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
         embeddings = torch.cat((cls_tokens, embeddings), dim=1)
 
-        # add positional encoding to each token
+
         embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
 
         if bool_masked_pos is not None:
-            # embeddings = torch.where(
-            #     bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
-            # )
+
             B,S,D = embeddings.shape
             batch_indices = torch.arange(B).unsqueeze(1)
             embeddings = embeddings[batch_indices, bool_masked_pos]
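Note on the hunk above: it replaces the commented-out torch.where path, which would have overwritten masked positions with the learned mask token, with advanced indexing that keeps only the token indices listed in bool_masked_pos, so the sequence physically shrinks (FLIP-style token dropping). Despite its name, bool_masked_pos is consumed here as per-sample integer indices of the tokens to keep, not as a boolean mask. A minimal, runnable sketch of the gather semantics, with illustrative shapes and index values:

import torch

# Toy shapes: batch of 2, sequence of 5 tokens (CLS + 4 patches), dim 4.
B, S, D = 2, 5, 4
embeddings = torch.randn(B, S, D)

# Per-sample indices of the tokens to KEEP (illustrative values;
# index 0 retains the CLS token in each sample).
bool_masked_pos = torch.tensor([[0, 1, 3],
                                [0, 2, 4]])

# The (B, 1) column broadcasts against the (B, 3) index matrix,
# selecting embeddings[b, bool_masked_pos[b, j]] for every b, j.
batch_indices = torch.arange(B).unsqueeze(1)
kept = embeddings[batch_indices, bool_masked_pos]
print(kept.shape)  # torch.Size([2, 3, 4])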
@@ -140,13 +136,11 @@ class DINOv2_MLP(nn.Module):
 
         dino_outputs = self.dinov2(**img_inputs)
         dino_seq = dino_outputs.last_hidden_state
-        # B,S,_ = dino_seq.shape
-        # dino_seq = dino_seq.view(B*S,-1)
+
         dino_seq = dino_seq[:,0,:]
 
         down_sample_out = self.down_sampler(dino_seq)
-        # down_sample_out = down_sample_out.view(B,S,-1)
-        # down_sample_out = down_sample_out[:,0,:]
+
 
         return down_sample_out
 
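Note on the forward pass: the commented-out variant ran every token through the down-sampler and only then selected position 0; the new code selects the CLS token first and down-samples that single vector, which gives the same result as long as the down-sampler acts on each token independently (true for a per-token MLP like the file's MLP_dim) while skipping the extra forward passes. A minimal sketch of the simplified path; the hidden sizes and the nn.Sequential stand-in for the file's MLP_dim down-sampler are assumptions:

import torch
import torch.nn as nn

# Stand-in for self.down_sampler (the diff's MLP_dim); sizes are illustrative.
hidden_dim, out_dim = 1024, 256
down_sampler = nn.Sequential(
    nn.Linear(hidden_dim, 512),
    nn.GELU(),
    nn.Linear(512, out_dim),
)

# (B, 1 + num_patches, D) hidden states, as returned by the DINOv2 backbone.
last_hidden_state = torch.randn(2, 257, hidden_dim)

cls_only = last_hidden_state[:, 0, :]  # keep only the CLS token per sample
out = down_sampler(cls_only)
print(out.shape)  # torch.Size([2, 256])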
 
 