hackergeek98 commited on
Commit
46dbcaa
·
verified ·
1 Parent(s): 369cc6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -1,30 +1,43 @@
1
  import gradio as gr
2
- from transformers import BlipProcessor, BlipForConditionalGeneration
3
  from PIL import Image
4
- import requests
 
 
5
 
6
- # Load model and processor
7
  model_name = "hackergeek/radiology-image-captioning"
8
- processor = BlipProcessor.from_pretrained(model_name)
9
  model = BlipForConditionalGeneration.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
10
 
11
  def generate_caption(image):
12
  """
13
- Generates a radiology caption for a given image
14
  """
15
- if isinstance(image, str): # if image is a URL
16
- image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
17
- elif isinstance(image, Image.Image):
18
  image = image.convert("RGB")
19
-
20
- inputs = processor(images=image, return_tensors="pt")
21
- out = model.generate(**inputs)
22
- caption = processor.decode(out[0], skip_special_tokens=True)
 
23
  return caption
24
 
25
- # Create Gradio interface
26
  title = "Radiology Image Captioning"
27
- description = "Upload a radiology image (X-ray, CT, MRI) and get an automatic caption generated by the `hackergeek/radiology-image-captioning` model."
 
 
 
28
 
29
  iface = gr.Interface(
30
  fn=generate_caption,
 
1
  import gradio as gr
 
2
  from PIL import Image
3
+ import torch
4
+ from torchvision import transforms
5
+ from transformers import BlipForConditionalGeneration, AutoTokenizer
6
 
7
+ # Load model and tokenizer
8
  model_name = "hackergeek/radiology-image-captioning"
 
9
  model = BlipForConditionalGeneration.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ # Manual preprocessing
13
+ preprocess = transforms.Compose([
14
+ transforms.Resize((384, 384)), # BLIP models usually expect 384x384
15
+ transforms.ToTensor(),
16
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
17
+ std=[0.229, 0.224, 0.225]),
18
+ ])
19
 
20
  def generate_caption(image):
21
  """
22
+ Generate radiology caption for a PIL image.
23
  """
24
+ if not isinstance(image, Image.Image):
25
+ image = Image.open(image).convert("RGB")
26
+ else:
27
  image = image.convert("RGB")
28
+
29
+ pixel_values = preprocess(image).unsqueeze(0) # add batch dimension
30
+ with torch.no_grad():
31
+ outputs = model.generate(pixel_values=pixel_values)
32
+ caption = tokenizer.decode(outputs[0], skip_special_tokens=True)
33
  return caption
34
 
35
+ # Gradio Interface
36
  title = "Radiology Image Captioning"
37
+ description = (
38
+ "Upload a radiology image (X-ray, CT, MRI) and get an automatic caption "
39
+ "generated by the `hackergeek/radiology-image-captioning` model."
40
+ )
41
 
42
  iface = gr.Interface(
43
  fn=generate_caption,