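"""Gradio app for chest X-ray report generation with Hindi translation.

A frozen BiomedCLIP vision encoder plus a trained projection head condition
BioGPT to generate an English radiology finding from a chest X-ray; the
finding is then translated to Hindi with IndicTrans2. Model weights are
downloaded from the Hugging Face Hub repo given by REPO_ID.
"""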
import os
from typing import Dict, Any

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import open_clip
from transformers import (
    BioGptTokenizer,
    BioGptForCausalLM,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
import gradio as gr
from IndicTransToolkit import IndicProcessor
from huggingface_hub import hf_hub_download
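
# NOTE: IndicTransToolkit is the companion preprocessing library for
# IndicTrans2; if the import fails, it can be installed from its GitHub
# repository.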

# --- Configuration ---
REPO_ID = "Robinhood135/biogptm1"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BIOMEDCLIP_MODEL_NAME = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
PREFIX_LENGTH = 10
# Kept verbatim (including the missing space after the period): editing the
# prompt would change its tokenization and thus what the model conditions on.
PROMPT_TEXT = "You are a Radiologist.The chest image findings are:"

# Deterministic beam-search decoding settings passed to BioGPT's generate().
BEST_STRATEGY_PARAMS = {
    "num_beams": 4,
    "do_sample": False,
    "repetition_penalty": 1.2,
    "max_new_tokens": 100,
    "min_new_tokens": 10,
}


def freeze_module(module: nn.Module):
    """Disable gradients so the module's weights stay fixed."""
    for param in module.parameters():
        param.requires_grad = False
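

# Architecture: the frozen BiomedCLIP ViT encodes the X-ray into a single
# feature vector; a small MLP ("projection head") maps that vector to
# PREFIX_LENGTH pseudo-token embeddings, which are prepended to the text
# prompt so BioGPT decodes a report conditioned on the image.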
class BiomedCLIPBioGPTGenerator(nn.Module):
    def __init__(self, tokenizer, model_name=BIOMEDCLIP_MODEL_NAME, prefix_length=PREFIX_LENGTH):
        super().__init__()
        self.tokenizer = tokenizer
        self.prefix_length = prefix_length
        self.clip_model, _, _ = open_clip.create_model_and_transforms(model_name)

        # Use the frozen BiomedCLIP vision tower as the image encoder.
        self.image_encoder = self.clip_model.visual if hasattr(self.clip_model, 'visual') else self.clip_model.encode_image
        freeze_module(self.image_encoder)

        # Probe with a dummy image to discover the encoder's output dimension.
        with torch.no_grad():
            dummy_features = self.image_encoder(torch.randn(1, 3, 224, 224))
            if isinstance(dummy_features, tuple):
                dummy_features = dummy_features[0]
            self.embed_dim = dummy_features.shape[-1]

        self.biogpt = BioGptForCausalLM.from_pretrained('microsoft/biogpt')
        self.biogpt.resize_token_embeddings(len(self.tokenizer))
        self.gpt_hidden_dim = self.biogpt.config.hidden_size
        self.biogpt.config.pad_token_id = self.tokenizer.pad_token_id

        self.projection_head = nn.Sequential(
            nn.Linear(self.embed_dim, self.prefix_length * self.gpt_hidden_dim),
            nn.Tanh(),
            nn.Linear(self.prefix_length * self.gpt_hidden_dim, self.prefix_length * self.gpt_hidden_dim),
        )

    @torch.no_grad()
    def get_prefix_embeddings(self, images):
        clip_features = self.image_encoder(images).float()
        prefix_embeds = self.projection_head(clip_features)
        return prefix_embeds.view(-1, self.prefix_length, self.gpt_hidden_dim)

    def get_text_embeddings(self, input_ids):
        return self.biogpt.get_input_embeddings()(input_ids)


@torch.no_grad()
def generate_report(model, pil_image: Image.Image, method_params: Dict[str, Any]):
    model.eval()

    # Preprocess exactly as BiomedCLIP expects: 224x224, CLIP normalization.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])
    image_tensor = transform(pil_image.convert('RGB')).unsqueeze(0).to(device)

    # Image prefix followed by the embedded text prompt.
    prefix_embeds = model.get_prefix_embeddings(image_tensor)
    prompt_data = model.tokenizer(PROMPT_TEXT, return_tensors="pt").to(device)
    prompt_embeds = model.get_text_embeddings(prompt_data["input_ids"])
    combined_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)

    # Every prefix position is attended to.
    prefix_att_mask = torch.ones(prefix_embeds.shape[:2], dtype=torch.long, device=device)
    combined_att_mask = torch.cat([prefix_att_mask, prompt_data["attention_mask"]], dim=1)

    generation_args = {
        "inputs_embeds": combined_embeds,
        "attention_mask": combined_att_mask,
        "pad_token_id": model.tokenizer.pad_token_id,
        "eos_token_id": model.tokenizer.eos_token_id,
        "use_cache": True,
    }
    generation_args.update(method_params)

    generated_ids = model.biogpt.generate(**generation_args)
    full_text = model.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Depending on the transformers version, generate() with inputs_embeds may
    # or may not echo the prompt text, so strip it only when present.
    if full_text.startswith(PROMPT_TEXT):
        text = full_text[len(PROMPT_TEXT):].strip()
    else:
        text = full_text

    return text if text.strip() else "[BLANK/FAILED GENERATION]"


def load_trained_generator():
    print(f"Loading Report Generator model from {REPO_ID}...")

    # Fetch the three fine-tuned checkpoints from the Hugging Face Hub.
    try:
        clip_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biomedclipp.pth")
        gpt_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biogptt.pth")
        proj_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="projectorr.pth")
    except Exception as e:
        raise FileNotFoundError(f"Failed to download one or more checkpoint files from {REPO_ID}. Error: {e}")

    base_tokenizer = BioGptTokenizer.from_pretrained('microsoft/biogpt')
    if base_tokenizer.pad_token is None:
        base_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    model = BiomedCLIPBioGPTGenerator(base_tokenizer).to(device)

    # The CLIP checkpoint may wrap the weights under different keys; strip any
    # 'visual.' prefixes so the keys match the bare vision tower.
    clip_checkpoint = torch.load(clip_ckpt_path, map_location=device)
    state_dict = clip_checkpoint.get('model_state_dict', clip_checkpoint.get('state_dict', clip_checkpoint))
    visual_state = {k.replace('model.visual.', '').replace('visual.', ''): v for k, v in state_dict.items() if 'visual' in k}
    model.image_encoder.load_state_dict(visual_state, strict=False)

    model.biogpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
    model.projection_head.load_state_dict(torch.load(proj_ckpt_path, map_location=device))

    model.eval()
    print("✅ Report Generator loaded successfully.")
    return model


def load_translator():
    print("Loading Translation model (IndicTrans2)...")
    try:
        ip = IndicProcessor(inference=True)
        model_name = "ai4bharat/indictrans2-en-indic-dist-200M"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(device)
        print("✅ Translator loaded successfully.")
        return ip, tokenizer, model
    except Exception as e:
        print(f"Error loading translation model: {e}")
        return None, None, None


# Load both models once at import time so every Gradio request reuses them.
GENERATOR_MODEL = load_trained_generator()
IP, TRANS_TOKENIZER, TRANS_MODEL = load_translator()


@torch.no_grad()
def translate_report(english_text: str, target_lang: str = "hin_Deva") -> str:
    if TRANS_MODEL is None or not english_text:
        return "[Translation Model Not Available or No Text to Translate]"

    # IndicProcessor adds the language tags and normalization IndicTrans2 expects.
    batch = IP.preprocess_batch([english_text], src_lang="eng_Latn", tgt_lang=target_lang, visualize=False)
    batch = TRANS_TOKENIZER(batch, padding="longest", truncation=True, max_length=256, return_tensors="pt").to(device)

    outputs = TRANS_MODEL.generate(**batch, num_beams=5, num_return_sequences=1, max_length=256, use_cache=False)

    outputs = TRANS_TOKENIZER.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    translated_text = IP.postprocess_batch(outputs, lang=target_lang)[0]
    return translated_text
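
# target_lang follows FLORES-200 codes; the en-indic IndicTrans2 model also
# supports other Indian languages, e.g. "ben_Beng" (Bengali) or "tam_Taml"
# (Tamil), which can be passed straight to translate_report().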


def inference_wrapper(input_image: Image.Image):
    if input_image is None:
        return "Please upload a chest X-ray image.", "[No English Report]"

    try:
        raw_english_report = generate_report(GENERATOR_MODEL, input_image, BEST_STRATEGY_PARAMS)
    except Exception as e:
        raw_english_report = f"An error occurred during generation: {e}"
        return raw_english_report, "[Translation Skipped]"

    try:
        hindi_report = translate_report(raw_english_report, target_lang="hin_Deva")
    except Exception as e:
        hindi_report = f"[Translation failed: {e}]"

    return raw_english_report, hindi_report


if __name__ == "__main__":
    # Example images expected in the local examples/ directory.
    EXAMPLE_FILENAMES = [
        "001c3589-7aed3964-f06ba8d5-03882592-d77f222c.jpg",
        "004438db-4a5d6ab3-acc6c408-5dce0934-7d30b269.jpg",
        "0006f2ea-d44c6b5e-aeea6fd2-a974657c-90a39211.jpg",
        "0008ba07-4e43d6f4-fc692a96-c18a27a8-10eea0cd.jpg",
        "001526e1-0d0b8a2d-87e74f7e-72646210-c635fee4.jpg",
        "00438e51-4f75714b-943c8edd-6740491f-f8307602.jpg",
        "001c78df-8ce750bd-c100a8e0-2874ea0e-09cdbd4e.jpg",
        "000b9235-69b5b7e2-1ec32996-50f79b97-46f939cf.jpg",
        "001198e2-a2adcc23-7253eb78-0dcb5eaa-b10ed183.jpg",
        "0003fc7c-3dfce751-9ff36dc3-8fa4f6d9-0515ce50.jpg",
        "0018ff6b-8ad1196f-823030d0-1141b667-2a1a117a.jpg",
        "00068d26-8d583659-af7de1da-fc6c0476-d94aada1.jpg",
        "00196af8-50d17b31-b1b5a7be-da90b7e6-fd3a8004.jpg",
        "004017bd-6506697c-3ead0e70-548114b7-2af62447.jpg",
        "00059571-ade80b6c-7931ddb8-b486c6c1-1e543b22.jpg",
        "00419c98-6f4860a1-3dee986d-8e2ceadc-d2fd30ae.jpg",
        "000ffbff-3d93bcef-da8b17cd-fbcede53-51728df9.jpg",
        "0016e39b-d0cad5f2-eecb7ae8-4db8b8f2-0b366f1a.jpg",
        "00469c3d-4ebf8374-055428f7-d798daca-3e37d354.jpg",
        "0013ac79-5eea664c-7ef52c71-7e5a25f3-013715fc.jpg",
    ]
    examples = [[os.path.join("examples", f)] for f in EXAMPLE_FILENAMES]

    input_image = gr.Image(type="pil", label="Upload Chest X-ray Image")
    output_en = gr.Textbox(label="Generated Radiology Report (English)", lines=5)
    output_hi = gr.Textbox(label="Generated Radiology Report (Hindi/हिन्दी)", lines=5)

    app = gr.Interface(
        fn=inference_wrapper,
        inputs=input_image,
        outputs=[output_en, output_hi],
        title="🔬 Cascading BiomedCLIP-BioGPT & IndicTrans2 Report Generator",
        description="Upload a chest X-ray image to generate a radiology finding in English and automatically translate it to Hindi.",
        examples=examples,
        cache_examples=False,
    )
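
    # launch() serves the UI on a local server (http://127.0.0.1:7860 by
    # default); pass share=True to launch() for a temporary public link.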
print("\nStarting Gradio interface...") |
|
|
app.launch() |