import os

import torch
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForCausalLM, CLIPVisionModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BobVMLAdapter(torch.nn.Module):
    def __init__(self, lang_embed_dim, clip_dim):
        super().__init__()
        self.activation = torch.nn.ReLU()
        self.layer1 = torch.nn.Linear(clip_dim, 500)
        self.layer2 = torch.nn.Linear(500, 500)
        self.layer3 = torch.nn.Linear(500, lang_embed_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        x = self.activation(x)
        x = self.layer3(x)
        output = self.activation(x)
        return output
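
# Note on dimensions (added for clarity): with the default config values below, the
# adapter above maps CLIP ViT-L/14 hidden states (1024-dim) through two 500-unit ReLU
# layers into the 2048-dim token-embedding space of Llama-3.2-1B.
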
class BobVLMConfig(PretrainedConfig):
    def __init__(
        self,
        lang_embed_dim=2048,
        clip_dim=1024,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.lang_embed_dim = lang_embed_dim
        self.clip_dim = clip_dim

    def to_dict(self):
        """Convert config to dictionary format."""
        return {k: v for k, v in self.__dict__.items()}

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """Create config from dictionary."""
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        config = cls()
        for key, value in config_dict.items():
            setattr(config, key, value)
        if return_unused_kwargs:
            return config, kwargs
        return config


class BobVLM(PreTrainedModel):
    config_class = BobVLMConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.vit = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        self.adapter = BobVMLAdapter(config.lang_embed_dim, config.clip_dim).to(device)
        self.language_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3.2-1B-Instruct').to(device)

    def __extend_attention_mask(self, atten_mask, atten_to_img=True, num_added_tokens=257):
        batch_size, original_seq_length = atten_mask.shape
        # Create a new attention mask that also covers the prepended image tokens:
        # all ones if the text should attend to the image, all zeros otherwise
        if atten_to_img:
            extended_mask = torch.ones(
                batch_size,
                original_seq_length + num_added_tokens,
                dtype=atten_mask.dtype,
                device=atten_mask.device
            )
        else:
            extended_mask = torch.zeros(
                batch_size,
                original_seq_length + num_added_tokens,
                dtype=atten_mask.dtype,
                device=atten_mask.device
            )
        # Copy the original attention mask into the trailing text positions
        # (the image tokens are prepended, so the text mask occupies the end)
        extended_mask[:, -original_seq_length:] = atten_mask
        return extended_mask
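
    # Illustrative shapes (added note, not from the original code): with the default 257
    # image tokens (256 CLIP ViT-L/14 patches + 1 CLS token), a (2, 5) text mask becomes a
    # (2, 262) mask whose first 257 columns are all ones (attend to the image) or all zeros
    # (ignore the image), followed by the original text mask.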

    def process_inputs(self, input_ids, attention_mask, pixel_values, attend_to_img_tokens=True):
        # Process language inputs
        if input_ids is not None:
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            final_embeddings = self.language_model.model.embed_tokens(input_ids).to(device)
        # Process visual inputs
        if pixel_values is not None:
            pixel_values = pixel_values.to(device)
            vision_outputs = self.vit(pixel_values)
            # Use the patch-level hidden states from the CLIP vision transformer
            image_embeddings = vision_outputs.last_hidden_state
            # Pass the image embeddings through the adapter and prepend them to the text embeddings
            adapted_image_embeddings = self.adapter(image_embeddings).to(device)
            final_embeddings = torch.cat((adapted_image_embeddings, final_embeddings), dim=1).to(device)
            attention_mask = self.__extend_attention_mask(attention_mask, atten_to_img=attend_to_img_tokens).to(device)
        return final_embeddings, attention_mask

    def forward(self, input_ids=None, attention_mask=None, pixel_values=None, attend_to_img_tokens=True, labels=None, **kwargs):
        # Prefer explicitly passed arguments, falling back to any duplicates in kwargs
        input_ids = input_ids if input_ids is not None else kwargs.get('input_ids')
        attention_mask = attention_mask if attention_mask is not None else kwargs.get('attention_mask')
        pixel_values = pixel_values if pixel_values is not None else kwargs.get('pixel_values')
        labels = labels if labels is not None else kwargs.get('labels')
        final_embeddings, attention_mask = self.process_inputs(input_ids, attention_mask, pixel_values, attend_to_img_tokens)
        if labels is not None:
            # Labels are expected to match the extended sequence length, i.e. with the
            # 257 prepended image positions filled with the ignore index (-100)
            pred = self.language_model(inputs_embeds=final_embeddings, attention_mask=attention_mask, labels=labels)
        else:
            pred = self.language_model(inputs_embeds=final_embeddings, attention_mask=attention_mask)
        return pred

    def generate(self, input_ids=None, attention_mask=None, pixel_values=None, attend_to_img_tokens=True, max_new_tokens=50, temperature=0.3, top_p=0.9, **kwargs):
        input_ids = input_ids if input_ids is not None else kwargs.pop('input_ids', None)
        attention_mask = attention_mask if attention_mask is not None else kwargs.pop('attention_mask', None)
        pixel_values = pixel_values if pixel_values is not None else kwargs.pop('pixel_values', None)
        final_embeddings, attention_mask = self.process_inputs(input_ids, attention_mask, pixel_values, attend_to_img_tokens)
        # do_sample=True is needed for temperature / top_p to take effect during generation
        return self.language_model.generate(
            inputs_embeds=final_embeddings,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )


from huggingface_hub import login

# Authenticate to access the gated Llama checkpoint; the token is read from the
# HF_TOKEN environment variable rather than being hardcoded in source.
login(token=os.environ["HF_TOKEN"])
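
# Minimal usage sketch (illustrative, not part of the original file). It assumes a local
# image at "example.jpg" and access to the gated meta-llama/Llama-3.2-1B-Instruct
# checkpoint; the tokenizer and image processor simply mirror the backbones loaded above.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer, CLIPImageProcessor

    tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.2-1B-Instruct')
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

    # NOTE: the adapter weights are randomly initialized here; load a trained checkpoint
    # (e.g. BobVLM.from_pretrained on a saved directory) for meaningful output.
    model = BobVLM(BobVLMConfig())
    model.eval()

    text_inputs = tokenizer("Describe the image.", return_tensors="pt")
    pixel_values = image_processor(images=Image.open("example.jpg"), return_tensors="pt").pixel_values

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
            pixel_values=pixel_values,
        )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))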