hackergeek98 committed on
Commit 76b0664 · verified · 1 Parent(s): 0f260f6

Update app.py

Files changed (1):
  app.py  +30 -114
app.py CHANGED
@@ -1,125 +1,41 @@
-# app.py
 import gradio as gr
-import torch
+from transformers import BlipProcessor, BlipForConditionalGeneration
 from PIL import Image
-import json
-import os
-from tokenizers import ByteLevelBPETokenizer # Changed from Tokenizer
-from torchvision import transforms

-# ============================================================
-# 3. Model (Tiny CNN + Transformer Decoder) - Re-define model classes
-# ============================================================
-class CNNEncoder(torch.nn.Module):
-    def __init__(self, embed_dim=128):
-        super().__init__()
-        self.cnn = torch.nn.Sequential(
-            torch.nn.Conv2d(3, 32, 3, 2, 1), torch.nn.ReLU(),
-            torch.nn.Conv2d(32, 64, 3, 2, 1), torch.nn.ReLU(),
-            torch.nn.Conv2d(64, 128, 3, 2, 1), torch.nn.ReLU(),
-            torch.nn.AdaptiveAvgPool2d((1,1))
-        )
-        self.fc = torch.nn.Linear(128, embed_dim)
-    def forward(self, x):
-        x = self.cnn(x)
-        x = x.view(x.size(0), -1)
-        return self.fc(x)
-
-class TransformerDecoder(torch.nn.Module):
-    def __init__(self, vocab_size, embed_dim=128, nhead=4, num_layers=2, max_len=40):
-        super().__init__()
-        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
-        decoder_layer = torch.nn.TransformerDecoderLayer(d_model=embed_dim, nhead=nhead)
-        self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
-        self.fc_out = torch.nn.Linear(embed_dim, vocab_size)
-        self.pos_embed = torch.nn.Embedding(max_len, embed_dim)
-
-    def forward(self, tgt, memory):
-        positions = torch.arange(0, tgt.shape[1], device=tgt.device).unsqueeze(0)
-        tgt_emb = self.embed(tgt) + self.pos_embed(positions)
-        memory = memory.unsqueeze(0)
-        out = self.decoder(tgt_emb.transpose(0,1), memory)
-        return self.fc_out(out.transpose(0,1))
-
-class ImageCaptionModel(torch.nn.Module):
-    def __init__(self, vocab_size, embed_dim=128):
-        super().__init__()
-        self.encoder = CNNEncoder(embed_dim)
-        self.decoder = TransformerDecoder(vocab_size, embed_dim)
-    def forward(self, images, captions):
-        feats = self.encoder(images)
-        return self.decoder(captions, feats)
-
-# ============================================================
-# Load the tokenizer and model manually
-# ============================================================
-
-# Load config
-with open("hackergeek/radiology-image-captioning/config.json", "r") as f:
-    config = json.load(f)
-
-# Load tokenizer - Corrected to use ByteLevelBPETokenizer with both files
-tokenizer = ByteLevelBPETokenizer("radiology_caption_model/vocab.json", "radiology_caption_model/merges.txt")
-
-# Instantiate the model with config parameters
-model = ImageCaptionModel(
-    vocab_size=config["vocab_size"],
-    embed_dim=config["embed_dim"]
-)
-
-# Load the model weights
-model.load_state_dict(torch.load("radiology_caption_model/pytorch_model.bin", map_location=torch.device('cpu')))
-model.eval() # Set model to evaluation mode
-
-# Define image transformations
-image_size = 128 # Must match training image size
-img_transforms = transforms.Compose([
-    transforms.Resize((image_size, image_size)),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485,0.456,0.406],
-                         std=[0.229,0.224,0.225]),
-])
-
-# Function to generate caption
+# Load model and processor
+model_name = "hackergeek/radiology-image-captioning"
+processor = BlipProcessor.from_pretrained(model_name)
+model = BlipForConditionalGeneration.from_pretrained(model_name)
+
 def generate_caption(image):
-    # Preprocess image
-    img_tensor = img_transforms(image).unsqueeze(0) # Add batch dimension
-
-    # Generate caption
-    with torch.no_grad():
-        # Get image features
-        image_features = model.encoder(img_tensor)
-
-        # Start caption generation with BOS token
-        # (We assume BOS token ID is 2 from tokenizer training in cell 1)
-        # (Padding token ID is 0)
-        caption_tokens = [tokenizer.token_to_id("[BOS]")]
-        max_len = config["max_len"] if "max_len" in config else 40 # Use max_len from config, fallback to 40
-
-        for _ in range(max_len - 1): # -1 because BOS is already there
-            input_tokens = torch.tensor(caption_tokens).unsqueeze(0) # Add batch dimension
-            output = model.decoder(input_tokens, image_features)
-            last_token_logits = output[0, -1, :]
-            predicted_token_id = torch.argmax(last_token_logits).item()
-
-            caption_tokens.append(predicted_token_id)
-
-            # Stop if EOS token is generated
-            if predicted_token_id == tokenizer.token_to_id("[EOS]"):
-                break
-
-    # Decode the output tokens, excluding BOS and EOS (if present)
-    decoded_caption = tokenizer.decode(caption_tokens[1:-1] if caption_tokens[-1] == tokenizer.token_to_id("[EOS]") else caption_tokens[1:])
-    return decoded_caption
+    """
+    Generates a radiology caption for a given image
+    """
+    if isinstance(image, str):  # if image is a URL
+        image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
+    elif isinstance(image, Image.Image):
+        image = image.convert("RGB")
+
+    inputs = processor(images=image, return_tensors="pt")
+    out = model.generate(**inputs)
+    caption = processor.decode(out[0], skip_special_tokens=True)
+    return caption

 # Create Gradio interface
-interface = gr.Interface(
+title = "Radiology Image Captioning"
+description = "Upload a radiology image (X-ray, CT, MRI) and get an automatic caption generated by the `hackergeek/radiology-image-captioning` model."
+
+iface = gr.Interface(
     fn=generate_caption,
-    inputs=gr.Image(type="pil"),
-    outputs="text",
-    title="Radiology Image Captioning",
-    description="Upload a radiology image (X-ray, CT, MRI) to get an AI-generated caption."
+    inputs=gr.Image(type="pil", label="Upload Radiology Image"),
+    outputs=gr.Textbox(label="Generated Caption"),
+    title=title,
+    description=description,
+    examples=[
+        ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/medical_xray.png"]
+    ]
 )

 if __name__ == "__main__":
-    interface.launch(server_name="0.0.0.0", server_port=7860)
+    iface.launch()
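
For quick verification outside the Gradio UI, the snippet below loads the same checkpoint and captions one image directly. This is a minimal sketch, assuming the hackergeek/radiology-image-captioning repo is BLIP-compatible (as the updated app.py implies); the file name "sample_xray.png" and the decoding parameters (num_beams, max_new_tokens) are illustrative choices, not part of this commit.

# smoke_test.py - a minimal sketch, not part of this commit.
# Assumes the checkpoint is BLIP-compatible, as the updated app.py implies;
# "sample_xray.png" is a hypothetical local test image.
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

model_name = "hackergeek/radiology-image-captioning"
processor = BlipProcessor.from_pretrained(model_name)
model = BlipForConditionalGeneration.from_pretrained(model_name)

image = Image.open("sample_xray.png").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# generate() accepts the standard transformers decoding arguments; beam search
# with an explicit token budget is an illustrative alternative to the greedy
# defaults that app.py uses.
out = model.generate(**inputs, num_beams=3, max_new_tokens=60)
print(processor.decode(out[0], skip_special_tokens=True))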