# radio_cap / app.py
# Hugging Face Space by hackergeek98 — commit 46dbcaa ("Update app.py").
import gradio as gr
from PIL import Image
import torch
from torchvision import transforms
from transformers import BlipForConditionalGeneration, AutoTokenizer
# Load the fine-tuned BLIP captioning model and its tokenizer from the Hub.
model_name = "hackergeek/radiology-image-captioning"
model = BlipForConditionalGeneration.from_pretrained(model_name)
# from_pretrained already returns the model in eval mode; stated explicitly
# because captions must be generated with dropout disabled.
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Manual preprocessing that mirrors transformers' BlipImageProcessor defaults:
# a 384x384 bicubic resize, rescaling to [0, 1], then normalization with the
# OpenAI CLIP statistics used by BLIP (the previous ImageNet mean/std and
# bilinear resize did not match what BLIP checkpoints were trained on).
# NOTE(review): assumes this fine-tuned checkpoint kept the base BLIP
# preprocessing — confirm against its training script / preprocessor config.
BLIP_IMAGE_SIZE = 384
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

preprocess = transforms.Compose([
    transforms.Resize(
        (BLIP_IMAGE_SIZE, BLIP_IMAGE_SIZE),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),  # PIL RGB image -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
])
def generate_caption(image):
    """
    Generate a radiology caption for a single image.

    Args:
        image: A ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` delivers),
            or anything ``PIL.Image.open`` accepts, such as a file path or a
            binary file object.

    Returns:
        The caption produced by ``model.generate``, decoded with special
        tokens removed.

    Raises:
        gr.Error: If no image was provided. Gradio passes ``None`` when the
            user submits without uploading anything.
    """
    if image is None:
        # Show a readable message in the UI instead of a traceback from
        # Image.open(None).
        raise gr.Error("Please upload an image.")

    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    else:
        # convert() returns a new, fully loaded image, so the underlying file
        # handle can be released as soon as the context exits.
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    # Add a batch dimension and place the tensor on the same device as the
    # model so inference keeps working if the model is moved (e.g. to a GPU).
    pixel_values = preprocess(image).unsqueeze(0).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(pixel_values=pixel_values)

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Gradio interface
title = "Radiology Image Captioning"
# Build the model id from model_name so the UI text cannot drift from the
# checkpoint that is actually loaded.
description = (
    "Upload a radiology image (X-ray, CT, MRI) and get an automatic caption "
    f"generated by the `{model_name}` model."
)

iface = gr.Interface(
    fn=generate_caption,
    # type="pil" makes Gradio hand generate_caption a PIL image (or None).
    inputs=gr.Image(type="pil", label="Upload Radiology Image"),
    outputs=gr.Textbox(label="Generated Caption"),
    title=title,
    description=description,
    # NOTE(review): remote example URL — confirm it still resolves; if it
    # does not, Gradio may fail to load or cache this example.
    examples=[
        ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/medical_xray.png"]
    ],
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    iface.launch()