import torch
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForCausalLM, CLIPVisionModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



class BobVLMAdapter(torch.nn.Module):
    """MLP adapter that projects CLIP vision features into the language model's embedding space."""

    def __init__(self, lang_embed_dim, clip_dim):
        super().__init__()
        self.activation = torch.nn.ReLU()
        self.layer1 = torch.nn.Linear(clip_dim, 500)
        self.layer2 = torch.nn.Linear(500, 500)
        self.layer3 = torch.nn.Linear(500, lang_embed_dim)

    def forward(self, x):
        x = self.activation(self.layer1(x))
        x = self.activation(self.layer2(x))
        # Note: the final projection is also passed through ReLU, so the adapted
        # embeddings are non-negative.
        output = self.activation(self.layer3(x))

        return output
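
# Illustrative shape sketch (not executed here): CLIP ViT-L/14 emits 257 token
# embeddings of width 1024 per image (1 class token + 16x16 patches), and the
# adapter maps each of them into the language model's 2048-dim embedding space.
#
#   adapter = BobVLMAdapter(lang_embed_dim=2048, clip_dim=1024)
#   features = torch.randn(1, 257, 1024)   # (batch, image tokens, clip_dim)
#   adapter(features).shape                # torch.Size([1, 257, 2048])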


class BobVLMConfig(PretrainedConfig):
    """Configuration for BobVLM: dimensions of the language embeddings and CLIP vision features."""

    def __init__(
        self,
        lang_embed_dim=2048,
        clip_dim=1024,
        **kwargs,
    ):
        # Initialize the PretrainedConfig base class so standard HF config
        # attributes are set before PreTrainedModel uses them.
        super().__init__(**kwargs)
        self.lang_embed_dim = lang_embed_dim
        self.clip_dim = clip_dim

    def to_dict(self):
        """Convert config to dictionary format."""
        return {k: v for k, v in self.__dict__.items()}

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """Create config from dictionary."""
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        config = cls()
        for key, value in config_dict.items():
            setattr(config, key, value)
        # Follow the PretrainedConfig convention: return the unused kwargs only
        # when explicitly requested, otherwise return the config alone.
        if return_unused_kwargs:
            return config, kwargs
        return config

class BobVLM(PreTrainedModel):
    """Vision-language model: CLIP ViT-L/14 encoder + MLP adapter + Llama-3.2-1B-Instruct."""

    config_class = BobVLMConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.vit = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        self.adapter = BobVLMAdapter(config.lang_embed_dim, config.clip_dim).to(device)
        self.language_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").to(device)


    def __extend_attention_mask(self, atten_mask, atten_to_img=True, num_added_tokens=257):
        """Extend the text attention mask to cover the image tokens prepended to the sequence."""
        batch_size, original_seq_length = atten_mask.shape

        # Positions for the prepended image tokens are attended to (ones) or
        # ignored (zeros) depending on atten_to_img.
        if atten_to_img:
            extended_mask = torch.ones(
                batch_size,
                original_seq_length + num_added_tokens,
                dtype=atten_mask.dtype,
                device=atten_mask.device,
            )
        else:
            extended_mask = torch.zeros(
                batch_size,
                original_seq_length + num_added_tokens,
                dtype=atten_mask.dtype,
                device=atten_mask.device,
            )
        # Copy the original text attention mask into the trailing positions,
        # after the image-token positions.
        extended_mask[:, -original_seq_length:] = atten_mask

        return extended_mask
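
    # Example (sketch): a text mask of shape (1, 4) with the default 257 image
    # positions becomes a mask of shape (1, 261); its first 257 entries are all
    # ones (or all zeros when atten_to_img=False) and its last 4 entries are the
    # original text mask.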

    def process_inputs(self, input_ids, attention_mask, pixel_values, attend_to_img_tokens=True):
        if input_ids is None:
            raise ValueError("input_ids must be provided to build the input embeddings.")

        # Embed the text tokens with the language model's embedding table.
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
        final_embeddings = self.language_model.model.embed_tokens(input_ids).to(device)

        # Process visual inputs, if any.
        if pixel_values is not None:
            pixel_values = pixel_values.to(device)
            vision_outputs = self.vit(pixel_values)
            # Use the full sequence of token embeddings from the CLIP vision transformer.
            image_embeddings = vision_outputs.last_hidden_state
            # Project the image embeddings into the language model's embedding space.
            adapted_image_embeddings = self.adapter(image_embeddings).to(device)
            # Prepend the image embeddings to the text embeddings and extend the
            # attention mask to cover the added image positions.
            final_embeddings = torch.cat((adapted_image_embeddings, final_embeddings), dim=1).to(device)
            attention_mask = self.__extend_attention_mask(attention_mask, atten_to_img=attend_to_img_tokens).to(device)

        return final_embeddings, attention_mask

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None, attend_to_img_tokens=True, labels=None, **kwargs):
        # Prefer explicit arguments and fall back to kwargs; avoid `or` here,
        # which would truth-test a tensor.
        input_ids = input_ids if input_ids is not None else kwargs.get("input_ids")
        attention_mask = attention_mask if attention_mask is not None else kwargs.get("attention_mask")
        pixel_values = pixel_values if pixel_values is not None else kwargs.get("pixel_values")
        labels = labels if labels is not None else kwargs.get("labels")

        final_embeddings, attention_mask = self.process_inputs(input_ids, attention_mask, pixel_values, attend_to_img_tokens)

        if labels is not None:
            # Labels must match the full input length, i.e. account for the 257
            # prepended image positions (typically filled with -100).
            pred = self.language_model(inputs_embeds=final_embeddings, attention_mask=attention_mask, labels=labels)
        else:
            pred = self.language_model(inputs_embeds=final_embeddings, attention_mask=attention_mask)
        return pred

    def generate(self, input_ids=None, attention_mask=None, pixel_values=None, attend_to_img_tokens=True, max_new_tokens=50, temperature=0.3, top_p=0.9, **kwargs):
        input_ids = input_ids if input_ids is not None else kwargs.pop("input_ids", None)
        attention_mask = attention_mask if attention_mask is not None else kwargs.pop("attention_mask", None)
        pixel_values = pixel_values if pixel_values is not None else kwargs.pop("pixel_values", None)

        final_embeddings, attention_mask = self.process_inputs(input_ids, attention_mask, pixel_values, attend_to_img_tokens)
        # Generate from the combined image+text embeddings. Since no input_ids are
        # passed to the language model, only the newly generated token ids are
        # returned. temperature and top_p take effect only when the language
        # model's generation config enables sampling.
        return self.language_model.generate(
            inputs_embeds=final_embeddings,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )

import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub (required for the gated Llama-3.2 weights).
# Read the token from the environment rather than hardcoding a credential in source.
login(token=os.environ.get("HF_TOKEN"))
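

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming access to the gated Llama-3.2 weights and
    # a valid HF token. The processor, tokenizer, image URL and prompt below are
    # illustrative choices, not part of the original model code, and the adapter is
    # randomly initialized here, so outputs are only meaningful with trained weights.
    import requests
    from PIL import Image
    from transformers import AutoTokenizer, CLIPImageProcessor

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    model = BobVLM(BobVLMConfig())

    image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    text = tokenizer("Describe the image.", return_tensors="pt")
    output_ids = model.generate(
        input_ids=text.input_ids,
        attention_mask=text.attention_mask,
        pixel_values=pixel_values,
        max_new_tokens=50,
    )
    # generate() returns only the newly generated tokens, so decode them directly.
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))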