# BobVLM-1.5b / model.py
import torch
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForCausalLM, CLIPVisionModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
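
# BobVMLAdapter: a small 3-layer MLP with ReLU activations that projects CLIP
# vision features (clip_dim, 1024 for ViT-L/14) into the language model's token
# embedding space (lang_embed_dim, 2048 for Llama-3.2-1B). For patch features of
# shape (batch, 257, 1024) it returns a tensor of shape (batch, 257, 2048).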
class BobVMLAdapter(torch.nn.Module):
def __init__(self, lang_embed_dim, clip_dim):
super().__init__()
self.activation = torch.nn.ReLU()
self.layer1 = torch.nn.Linear(clip_dim, 500)
self.layer2 = torch.nn.Linear(500,500)
self.layer3 = torch.nn.Linear(500, lang_embed_dim)
def forward(self,x):
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
x = self.activation(x)
x = self.layer3(x)
output = self.activation(x)
return output
class BobVLMConfig(PretrainedConfig):
    def __init__(
        self,
        lang_embed_dim=2048,
        clip_dim=1024,
        **kwargs,
    ):
        # Dimensions of the language-model token embeddings and of the CLIP vision features.
        self.lang_embed_dim = lang_embed_dim
        self.clip_dim = clip_dim
        super().__init__(**kwargs)
def to_dict(self):
"""Convert config to dictionary format."""
return {k: v for k, v in self.__dict__.items()}
    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """Create config from dictionary, following the PretrainedConfig convention."""
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        config = cls()
        for key, value in config_dict.items():
            setattr(config, key, value)
        if return_unused_kwargs:
            return config, kwargs
        return config
class BobVLM(PreTrainedModel):
    config_class = BobVLMConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        # Vision encoder: CLIP ViT-L/14 produces 1024-dim patch embeddings.
        self.vit = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
        # Adapter: projects CLIP features into the language model's embedding space.
        self.adapter = BobVMLAdapter(config.lang_embed_dim, config.clip_dim).to(device)
        # Language model: Llama-3.2-1B-Instruct consumes the concatenated embeddings.
        self.language_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3.2-1B-Instruct').to(device)
def __extend_attention_mask(self,atten_mask, atten_to_img=True, num_added_tokens=257):
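        """Extend the attention mask to cover the prepended image tokens.

        num_added_tokens defaults to 257 because CLIP ViT-L/14 (224x224 input)
        emits 256 patch tokens plus one CLS token.
        """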
batch_size, original_seq_length = atten_mask.shape
# Create a new attention mask with the same initial mask and added tokens
if atten_to_img:
extended_mask = torch.ones(
batch_size,
original_seq_length + num_added_tokens,
dtype=atten_mask.dtype,
device=atten_mask.device
)
else:
extended_mask = torch.zeros(
batch_size,
original_seq_length + num_added_tokens,
dtype=atten_mask.dtype,
device=atten_mask.device
)
        # The image tokens are prepended, so the original mask occupies the trailing positions
        extended_mask[:, -original_seq_length:] = atten_mask
return extended_mask
def process_inputs(self, input_ids, attention_mask, pixel_values,attend_to_img_tokens=True):
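        """Embed the text tokens, encode the image, and prepend the adapted image embeddings."""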
# Process language inputs
if input_ids is not None:
input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
final_embeddings = self.language_model.model.embed_tokens(input_ids).to(device)
#process visual inputs
if pixel_values is not None:
pixel_values = pixel_values.to(device)
vision_outputs = self.vit(pixel_values)
            # Use the patch-token embeddings (last hidden state) from the CLIP vision transformer
            image_embeddings = vision_outputs.last_hidden_state
# Pass image embeddings through adapter
adapted_image_embeddings = self.adapter(image_embeddings).to(device)
            final_embeddings = torch.cat((adapted_image_embeddings, final_embeddings), dim=1).to(device)
attention_mask = self.__extend_attention_mask(attention_mask,atten_to_img=attend_to_img_tokens).to(device)
return final_embeddings,attention_mask
    def forward(self, input_ids=None, attention_mask=None, pixel_values=None, attend_to_img_tokens=True, labels=None, **kwargs):
        # Named arguments are bound directly above; explicit keyword arguments never land in **kwargs.
final_embeddings,attention_mask = self.process_inputs(input_ids,attention_mask,pixel_values,attend_to_img_tokens)
if labels is not None:
pred = self.language_model(inputs_embeds=final_embeddings,attention_mask=attention_mask,labels=labels)
else:
pred = self.language_model(inputs_embeds=final_embeddings,attention_mask=attention_mask)
return pred
    def generate(self, input_ids=None, attention_mask=None, pixel_values=None, attend_to_img_tokens=True, max_new_tokens=50, temperature=0.3, top_p=0.9, **kwargs):
        # Named arguments are bound directly above; no need to pop them back out of **kwargs.
final_embeddings,attention_mask = self.process_inputs(input_ids,attention_mask,pixel_values,attend_to_img_tokens)
return self.language_model.generate(inputs_embeds=final_embeddings,attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)

# Authenticate with the Hugging Face Hub (needed for the gated Llama 3.2 weights).
# Never hard-code access tokens in source; read the token from the environment instead.
import os
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(hf_token)
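
# Minimal usage sketch (assumptions: a CLIPImageProcessor for the vision tower and the
# Llama-3.2-1B-Instruct tokenizer for text; the repo's actual processor may differ, and
# "example.jpg" is a hypothetical local image path).
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoTokenizer, CLIPImageProcessor

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")

    model = BobVLM(BobVLMConfig()).eval()

    image = Image.open("example.jpg")  # hypothetical image
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    text_inputs = tokenizer("Describe this image.", return_tensors="pt")

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
            pixel_values=pixel_values,
            max_new_tokens=50,
        )
    # With inputs_embeds, generate returns only the newly generated tokens.
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))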