hackergeek98 committed on
Commit
54f2cdd
·
verified ·
1 Parent(s): 9a5c0b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -14
app.py CHANGED
@@ -1,26 +1,116 @@
1
- # app.p
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoModel, AutoFeatureExtractor
4
  import torch
5
  from PIL import Image
 
 
 
 
6
 
7
- # Load the tokenizer and model
8
- tokenizer = AutoTokenizer.from_pretrained("hackergeek/radiology-image-captioning")
9
- model = AutoModel.from_pretrained("hackergeek/radiology-image-captioning")
10
- feature_extractor = AutoFeatureExtractor.from_pretrained("hackergeek/radiology-image-captioning")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # Function to generate caption
13
  def generate_caption(image):
14
  # Preprocess image
15
- inputs = feature_extractor(images=image, return_tensors="pt")
16
-
17
- # Generate features
18
  with torch.no_grad():
19
- outputs = model.generate(**inputs)
20
-
21
- # Decode the output tokens
22
- caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
23
- return caption
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Create Gradio interface
26
  interface = gr.Interface(
 
1
+ # app.py
2
  import gradio as gr
 
3
  import torch
4
  from PIL import Image
5
+ import json
6
+ import os
7
+ from tokenizers import ByteLevelBPETokenizer # Changed from Tokenizer
8
+ from torchvision import transforms
9
 
10
+ # ============================================================
11
+ # 3. Model (Tiny CNN + Transformer Decoder) - Re-define model classes
12
+ # ============================================================
13
class CNNEncoder(torch.nn.Module):
    """Small convolutional encoder producing one embedding vector per image.

    Three stride-2 3x3 convolutions (3 -> 32 -> 64 -> 128 channels), each
    followed by a ReLU, are reduced by global average pooling and then
    projected to ``embed_dim`` with a linear layer.

    Args:
        embed_dim: Size of the embedding returned for each image.
    """

    def __init__(self, embed_dim=128):
        super().__init__()
        # Attribute names and the layer order inside the Sequential define the
        # state_dict keys (cnn.0, cnn.2, cnn.4, fc); keep them stable so saved
        # checkpoints keep loading.
        self.cnn = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 3, 2, 1), torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 2, 1), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 3, 2, 1), torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = torch.nn.Linear(128, embed_dim)

    def forward(self, x):
        """Encode a batch of 3-channel images.

        Args:
            x: Tensor of shape ``(batch, 3, height, width)``.

        Returns:
            Tensor of shape ``(batch, embed_dim)``.
        """
        features = self.cnn(x)
        # Pooling leaves (batch, 128, 1, 1); collapse the trailing 1x1 dims.
        features = torch.flatten(features, 1)
        return self.fc(features)
27
+
28
class TransformerDecoder(torch.nn.Module):
    """Transformer decoder that scores caption tokens given an image embedding.

    Token embeddings are summed with learned positional embeddings, passed
    through a stack of ``torch.nn.TransformerDecoderLayer`` modules that
    cross-attend to the image embedding, and projected to vocabulary logits.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        embed_dim: Model width; must match the width of ``memory``.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked decoder layers.
        max_len: Size of the positional-embedding table, i.e. the longest
            token sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, embed_dim=128, nhead=4, num_layers=2, max_len=40):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
        decoder_layer = torch.nn.TransformerDecoderLayer(d_model=embed_dim, nhead=nhead)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = torch.nn.Linear(embed_dim, vocab_size)
        self.pos_embed = torch.nn.Embedding(max_len, embed_dim)

    def forward(self, tgt, memory, tgt_mask=None):
        """Return vocabulary logits for every position of ``tgt``.

        Args:
            tgt: Integer token ids of shape ``(batch, seq_len)``.
            memory: Image embeddings of shape ``(batch, embed_dim)``.
            tgt_mask: Optional ``(seq_len, seq_len)`` attention mask forwarded
                to the decoder stack (e.g. a causal mask). ``None`` keeps the
                previous unmasked behaviour.
                NOTE(review): whether a causal mask should be applied depends
                on how the checkpoint was trained - confirm against the
                training code before enabling it at inference.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            IndexError: If ``seq_len`` exceeds the positional-embedding table.
        """
        seq_len = tgt.shape[1]
        table_size = self.pos_embed.num_embeddings
        if seq_len > table_size:
            # Fail with an explicit message instead of the opaque
            # "index out of range in self" raised by the embedding lookup.
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len {table_size}"
            )

        positions = torch.arange(seq_len, device=tgt.device).unsqueeze(0)
        tgt_emb = self.embed(tgt) + self.pos_embed(positions)
        # The decoder stack is sequence-first: memory becomes a length-1
        # sequence (1, batch, embed_dim) and tgt is (seq_len, batch, embed_dim).
        memory = memory.unsqueeze(0)
        out = self.decoder(tgt_emb.transpose(0, 1), memory, tgt_mask=tgt_mask)
        return self.fc_out(out.transpose(0, 1))
43
+
44
class ImageCaptionModel(torch.nn.Module):
    """Image-captioning model pairing ``CNNEncoder`` with ``TransformerDecoder``.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        embed_dim: Width shared by the image embedding and the decoder.
        nhead: Attention heads per decoder layer.
        num_layers: Number of decoder layers.
        max_len: Size of the decoder's positional-embedding table.

    The last three arguments default to the decoder's own defaults, so
    existing two-argument callers build exactly the same model as before.
    They must match the values used in training for a checkpoint to load.
    """

    def __init__(self, vocab_size, embed_dim=128, nhead=4, num_layers=2, max_len=40):
        super().__init__()
        self.encoder = CNNEncoder(embed_dim)
        self.decoder = TransformerDecoder(
            vocab_size,
            embed_dim,
            nhead=nhead,
            num_layers=num_layers,
            max_len=max_len,
        )

    def forward(self, images, captions):
        """Return per-position vocabulary logits for ``captions``.

        Args:
            images: Image batch passed to the encoder.
            captions: Token ids passed to the decoder as its target sequence.

        Returns:
            The decoder output for ``captions`` conditioned on the encoded
            images.
        """
        feats = self.encoder(images)
        return self.decoder(captions, feats)
52
+
53
+ # ============================================================
54
+ # Load the tokenizer and model manually
55
+ # ============================================================
56
+
57
# Directory holding the exported artifacts (config, tokenizer files, weights).
MODEL_DIR = "radiology_caption_model"

# Load config
with open(os.path.join(MODEL_DIR, "config.json"), "r", encoding="utf-8") as f:
    config = json.load(f)

# Load tokenizer - ByteLevelBPETokenizer needs both the vocab and merges files
tokenizer = ByteLevelBPETokenizer(
    os.path.join(MODEL_DIR, "vocab.json"),
    os.path.join(MODEL_DIR, "merges.txt"),
)

# Instantiate the model with config parameters
model = ImageCaptionModel(
    vocab_size=config["vocab_size"],
    embed_dim=config["embed_dim"],
)

# Load the model weights on CPU. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint cannot run code.
state_dict = torch.load(
    os.path.join(MODEL_DIR, "pytorch_model.bin"),
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()  # Set model to evaluation mode (disables dropout)

# Define image transformations
image_size = 128  # Must match training image size
img_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
82
 
83
  # Function to generate caption
84
  def generate_caption(image):
85
  # Preprocess image
86
+ img_tensor = img_transforms(image).unsqueeze(0) # Add batch dimension
87
+
88
+ # Generate caption
89
  with torch.no_grad():
90
+ # Get image features
91
+ image_features = model.encoder(img_tensor)
92
+
93
+ # Start caption generation with BOS token
94
+ # (We assume BOS token ID is 2 from tokenizer training in cell 1)
95
+ # (Padding token ID is 0)
96
+ caption_tokens = [tokenizer.token_to_id("[BOS]")]
97
+ max_len = config["max_len"] if "max_len" in config else 40 # Use max_len from config, fallback to 40
98
+
99
+ for _ in range(max_len - 1): # -1 because BOS is already there
100
+ input_tokens = torch.tensor(caption_tokens).unsqueeze(0) # Add batch dimension
101
+ output = model.decoder(input_tokens, image_features)
102
+ last_token_logits = output[0, -1, :]
103
+ predicted_token_id = torch.argmax(last_token_logits).item()
104
+
105
+ caption_tokens.append(predicted_token_id)
106
+
107
+ # Stop if EOS token is generated
108
+ if predicted_token_id == tokenizer.token_to_id("[EOS]"):
109
+ break
110
+
111
+ # Decode the output tokens, excluding BOS and EOS (if present)
112
+ decoded_caption = tokenizer.decode(caption_tokens[1:-1] if caption_tokens[-1] == tokenizer.token_to_id("[EOS]") else caption_tokens[1:])
113
+ return decoded_caption
114
 
115
  # Create Gradio interface
116
  interface = gr.Interface(