Spaces:
Sleeping
Sleeping
Update vision_tower.py
Browse files
- vision_tower.py +6 -12
vision_tower.py
CHANGED
|
@@ -53,9 +53,7 @@ class MLP_dim(nn.Module):
|
|
| 53 |
return self.net2(self.net1(x))
|
| 54 |
|
| 55 |
class FLIP_Dinov2Embeddings(Dinov2Embeddings):
|
| 56 |
-
"""
|
| 57 |
-
Construct the CLS token, mask token, position and patch embeddings.
|
| 58 |
-
"""
|
| 59 |
|
| 60 |
def __init__(self, config: Dinov2Config) -> None:
|
| 61 |
super().__init__(config)
|
|
@@ -65,17 +63,15 @@ class FLIP_Dinov2Embeddings(Dinov2Embeddings):
|
|
| 65 |
target_dtype = self.patch_embeddings.projection.weight.dtype
|
| 66 |
embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
|
| 67 |
|
| 68 |
-
|
| 69 |
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
| 70 |
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 71 |
|
| 72 |
-
|
| 73 |
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
| 74 |
|
| 75 |
if bool_masked_pos is not None:
|
| 76 |
-
# embeddings = torch.where(
|
| 77 |
-
# bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
|
| 78 |
-
# )
|
| 79 |
B,S,D = embeddings.shape
|
| 80 |
batch_indices = torch.arange(B).unsqueeze(1)
|
| 81 |
embeddings = embeddings[batch_indices, bool_masked_pos]
|
|
@@ -140,13 +136,11 @@ class DINOv2_MLP(nn.Module):
|
|
| 140 |
|
| 141 |
dino_outputs = self.dinov2(**img_inputs)
|
| 142 |
dino_seq = dino_outputs.last_hidden_state
|
| 143 |
-
|
| 144 |
-
# dino_seq = dino_seq.view(B*S,-1)
|
| 145 |
dino_seq = dino_seq[:,0,:]
|
| 146 |
|
| 147 |
down_sample_out = self.down_sampler(dino_seq)
|
| 148 |
-
|
| 149 |
-
# down_sample_out = down_sample_out[:,0,:]
|
| 150 |
|
| 151 |
return down_sample_out
|
| 152 |
|
|
|
|
| 53 |
return self.net2(self.net1(x))
|
| 54 |
|
| 55 |
class FLIP_Dinov2Embeddings(Dinov2Embeddings):
|
| 56 |
+
|
|
|
|
|
|
|
| 57 |
|
| 58 |
def __init__(self, config: Dinov2Config) -> None:
|
| 59 |
super().__init__(config)
|
|
|
|
| 63 |
target_dtype = self.patch_embeddings.projection.weight.dtype
|
| 64 |
embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
|
| 65 |
|
| 66 |
+
|
| 67 |
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
| 68 |
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 69 |
|
| 70 |
+
|
| 71 |
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
| 72 |
|
| 73 |
if bool_masked_pos is not None:
|
| 74 |
+
|
|
|
|
|
|
|
| 75 |
B,S,D = embeddings.shape
|
| 76 |
batch_indices = torch.arange(B).unsqueeze(1)
|
| 77 |
embeddings = embeddings[batch_indices, bool_masked_pos]
|
|
|
|
| 136 |
|
| 137 |
dino_outputs = self.dinov2(**img_inputs)
|
| 138 |
dino_seq = dino_outputs.last_hidden_state
|
| 139 |
+
|
|
|
|
| 140 |
dino_seq = dino_seq[:,0,:]
|
| 141 |
|
| 142 |
down_sample_out = self.down_sampler(dino_seq)
|
| 143 |
+
|
|
|
|
| 144 |
|
| 145 |
return down_sample_out
|
| 146 |
|