import os
import torch
import torch.nn as nn
from torchvision import transforms
from typing import Dict, Any
from PIL import Image
import open_clip
from transformers import (
    BioGptTokenizer, BioGptForCausalLM,
    AutoTokenizer, AutoModelForSeq2SeqLM
)
import gradio as gr
# NOTE: Ensure this library is installed on the Hugging Face Space
from IndicTransToolkit import IndicProcessor
from huggingface_hub import hf_hub_download  # New import for HF deployment

# --- 1. CONFIGURATION (Stage 1: Report Generation) ---
# NOTE: Update this REPO_ID to the actual Hugging Face repository where you upload your .pth files!
REPO_ID = "Robinhood135/biogptm1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- MODEL/DECODING PARAMS ---
BIOMEDCLIP_MODEL_NAME = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
PREFIX_LENGTH = 10
# Keep this string byte-for-byte identical to the prompt used during training
# (including the missing space after "Radiologist."), or generation may drift.
PROMPT_TEXT = "You are a Radiologist.The chest image findings are:"

# --- BEST DECODING STRATEGY (Beam Search) ---
BEST_STRATEGY_PARAMS = {
    "num_beams": 4,
    "do_sample": False,
    "repetition_penalty": 1.2,
    "max_new_tokens": 100,
    "min_new_tokens": 10,
}


# --- 2. MODEL CLASS (Stage 1) - Kept the same ---
def freeze_module(module: nn.Module):
    """Disable gradient updates for every parameter in the module."""
    for param in module.parameters():
        param.requires_grad = False


class BiomedCLIPBioGPTGenerator(nn.Module):
    """Maps frozen BiomedCLIP image features to a BioGPT prefix for report generation."""

    def __init__(self, tokenizer, model_name=BIOMEDCLIP_MODEL_NAME, prefix_length=PREFIX_LENGTH):
        super().__init__()
        self.tokenizer = tokenizer
        self.prefix_length = prefix_length

        self.clip_model, _, _ = open_clip.create_model_and_transforms(model_name)
        # Handle both cases: the image encoder exposed as .visual, or only via encode_image
        self.image_encoder = (
            self.clip_model.visual if hasattr(self.clip_model, 'visual')
            else self.clip_model.encode_image
        )
        freeze_module(self.image_encoder)

        # Probe the encoder once to discover its output feature dimension
        with torch.no_grad():
            dummy_features = self.image_encoder(torch.randn(1, 3, 224, 224))
            if isinstance(dummy_features, tuple):
                dummy_features = dummy_features[0]
            self.embed_dim = dummy_features.shape[-1]

        # Load BioGPT once; its default config is used as-is
        self.biogpt = BioGptForCausalLM.from_pretrained('microsoft/biogpt')
        self.biogpt.resize_token_embeddings(len(self.tokenizer))
        self.gpt_hidden_dim = self.biogpt.config.hidden_size
        self.biogpt.config.pad_token_id = self.tokenizer.pad_token_id

        # Projects one CLIP feature vector to prefix_length BioGPT token embeddings
        self.projection_head = nn.Sequential(
            nn.Linear(self.embed_dim, self.prefix_length * self.gpt_hidden_dim),
            nn.Tanh(),
            nn.Linear(self.prefix_length * self.gpt_hidden_dim, self.prefix_length * self.gpt_hidden_dim)
        )

    @torch.no_grad()
    def get_prefix_embeddings(self, images):
        clip_features = self.image_encoder(images).float()
        prefix_embeds = self.projection_head(clip_features)
        return prefix_embeds.view(-1, self.prefix_length, self.gpt_hidden_dim)

    def get_text_embeddings(self, input_ids):
        return self.biogpt.get_input_embeddings()(input_ids)
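# Illustrative shape check (a hedged sketch; not called anywhere in the app).
# With BiomedCLIP's ViT-B/16 tower the pooled feature is typically 512-dim and
# BioGPT's hidden size is 1024, so one image becomes a (1, PREFIX_LENGTH, 1024)
# prefix. Both dims are read from the models at runtime, so these numbers are
# assumptions for illustration, not guarantees.
def _check_prefix_shape(model: "BiomedCLIPBioGPTGenerator") -> torch.Size:
    """Return the prefix shape for one dummy image, e.g. torch.Size([1, 10, 1024])."""
    dev = next(model.projection_head.parameters()).device
    dummy = torch.randn(1, 3, 224, 224, device=dev)
    return model.get_prefix_embeddings(dummy).shape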
# --- 3. INFERENCE FUNCTION (Stage 1) - Kept the same ---
@torch.no_grad()
def generate_report(model, pil_image: Image.Image, method_params: Dict[str, Any]):
    model.eval()

    # 3.1 Apply image transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
    ])
    image_tensor = transform(pil_image.convert('RGB')).unsqueeze(0).to(device)

    # 3.2 Get prefix embeddings
    prefix_embeds = model.get_prefix_embeddings(image_tensor)

    # 3.3 Encode prompt text and prepend the image prefix
    prompt_data = model.tokenizer(PROMPT_TEXT, return_tensors="pt").to(device)
    prompt_embeds = model.get_text_embeddings(prompt_data["input_ids"])
    combined_embeds = torch.cat([prefix_embeds, prompt_embeds], dim=1)
    prefix_att_mask = torch.ones(prefix_embeds.shape[:2], dtype=torch.long, device=device)
    combined_att_mask = torch.cat([prefix_att_mask, prompt_data["attention_mask"]], dim=1)

    # 3.4 Generation parameters
    generation_args = {
        "inputs_embeds": combined_embeds,
        "attention_mask": combined_att_mask,
        "pad_token_id": model.tokenizer.pad_token_id,
        "eos_token_id": model.tokenizer.eos_token_id,
        "use_cache": True,
    }
    generation_args.update(method_params)

    # 3.5 Generate
    generated_ids = model.biogpt.generate(**generation_args)

    # 3.6 Decode and strip the prompt if it was echoed back
    full_text = model.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    if full_text.startswith(PROMPT_TEXT):
        text = full_text[len(PROMPT_TEXT):].strip()
    else:
        text = full_text
    return text if text.strip() else "[BLANK/FAILED GENERATION]"


# --- 4. MODEL LOADING (Stage 1) - MODIFIED FOR HF HUB ---
def load_trained_generator():
    print(f"Loading Report Generator model from {REPO_ID}...")

    # Download checkpoints from the Hugging Face Hub
    try:
        clip_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biomedclipp.pth")
        gpt_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="biogptt.pth")
        proj_ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="projectorr.pth")
    except Exception as e:
        raise FileNotFoundError(
            f"Failed to download one or more checkpoint files from {REPO_ID}. Error: {e}"
        )

    # Initialize tokenizer
    base_tokenizer = BioGptTokenizer.from_pretrained('microsoft/biogpt')
    if base_tokenizer.pad_token is None:
        base_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # Initialize model
    model = BiomedCLIPBioGPTGenerator(base_tokenizer).to(device)

    # Load the fine-tuned CLIP visual encoder weights
    clip_checkpoint = torch.load(clip_ckpt_path, map_location=device)
    state_dict = clip_checkpoint.get('model_state_dict', clip_checkpoint.get('state_dict', clip_checkpoint))
    # Keep only visual-encoder weights and strip wrapper prefixes from the keys
    visual_state = {
        k.replace('model.visual.', '').replace('visual.', ''): v
        for k, v in state_dict.items() if 'visual' in k
    }
    model.image_encoder.load_state_dict(visual_state, strict=False)

    # Load trained BioGPT and projection-head weights
    model.biogpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
    model.projection_head.load_state_dict(torch.load(proj_ckpt_path, map_location=device))

    model.eval()
    print("✅ Report Generator loaded successfully.")
    return model
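# Local smoke test for Stage 1 (a hedged sketch; not wired into the app).
# "examples/sample.jpg" is a placeholder path -- point it at any local chest
# X-ray image before running.
def _smoke_test_stage1(image_path: str = "examples/sample.jpg") -> str:
    """Load the generator and produce one English report from a local file."""
    model = load_trained_generator()
    with Image.open(image_path) as img:
        return generate_report(model, img, BEST_STRATEGY_PARAMS)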
print("Loading Translation model (IndicTrans2)...") try: # IndicTransToolkit library is assumed to be installed ip = IndicProcessor(inference=True) model_name = "ai4bharat/indictrans2-en-indic-dist-200M" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) # Note: If memory is an issue on the Space, you might need to use a smaller model or lower precision. model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(device) print("✅ Translator loaded successfully.") return ip, tokenizer, model except Exception as e: print(f"Error loading translation model: {e}") # Return dummy values if loading fails to prevent crash return None, None, None # Load models globally GENERATOR_MODEL = load_trained_generator() IP, TRANS_TOKENIZER, TRANS_MODEL = load_translator() # --- 6. TRANSLATION FUNCTION (Stage 2) - Kept the same --- @torch.no_grad() def translate_report(english_text: str, target_lang: str = "hin_Deva") -> str: if TRANS_MODEL is None or not english_text: return "[Translation Model Not Available or No Text to Translate]" # 6.1 Preprocessing batch = IP.preprocess_batch([english_text], src_lang="eng_Latn", tgt_lang=target_lang, visualize=False) batch = TRANS_TOKENIZER(batch, padding="longest", truncation=True, max_length=256, return_tensors="pt").to(device) # 6.2 Generation outputs = TRANS_MODEL.generate(**batch, num_beams=5, num_return_sequences=1, max_length=256, use_cache=False) # 6.3 Postprocessing outputs = TRANS_TOKENIZER.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True) translated_text = IP.postprocess_batch(outputs, lang=target_lang)[0] return translated_text # --- 7. GRADIO WRAPPER FUNCTION (Simplified) - Kept the same --- def inference_wrapper(input_image: Image.Image): if input_image is None: return "Please upload a chest X-ray image.", "[No English Report]" # STAGE 1: GENERATE RAW ENGLISH REPORT try: raw_english_report = generate_report(GENERATOR_MODEL, input_image, BEST_STRATEGY_PARAMS) except Exception as e: raw_english_report = f"An error occurred during generation: {e}" return raw_english_report, "[Translation Skipped]" # STAGE 2: TRANSLATE RAW ENGLISH REPORT try: hindi_report = translate_report(raw_english_report, target_lang="hin_Deva") except Exception as e: hindi_report = f"[Translation failed: {e}]" return raw_english_report, hindi_report # --- 8. 
# --- 8. GRADIO INTERFACE SETUP ---
if __name__ == "__main__":
    # Example image filenames (expected under ./examples)
    EXAMPLE_FILENAMES = [
        "001c3589-7aed3964-f06ba8d5-03882592-d77f222c.jpg",
        "004438db-4a5d6ab3-acc6c408-5dce0934-7d30b269.jpg",
        "0006f2ea-d44c6b5e-aeea6fd2-a974657c-90a39211.jpg",
        "0008ba07-4e43d6f4-fc692a96-c18a27a8-10eea0cd.jpg",
        "001526e1-0d0b8a2d-87e74f7e-72646210-c635fee4.jpg",
        "00438e51-4f75714b-943c8edd-6740491f-f8307602.jpg",
        "001c78df-8ce750bd-c100a8e0-2874ea0e-09cdbd4e.jpg",
        "000b9235-69b5b7e2-1ec32996-50f79b97-46f939cf.jpg",
        # "0041603e-059f400f-c509c746-0da5c413-ee889ec1.jpg",
        "001198e2-a2adcc23-7253eb78-0dcb5eaa-b10ed183.jpg",
        "0003fc7c-3dfce751-9ff36dc3-8fa4f6d9-0515ce50.jpg",
        "0018ff6b-8ad1196f-823030d0-1141b667-2a1a117a.jpg",
        "00068d26-8d583659-af7de1da-fc6c0476-d94aada1.jpg",
        "00196af8-50d17b31-b1b5a7be-da90b7e6-fd3a8004.jpg",
        "004017bd-6506697c-3ead0e70-548114b7-2af62447.jpg",
        "00059571-ade80b6c-7931ddb8-b486c6c1-1e543b22.jpg",
        "00419c98-6f4860a1-3dee986d-8e2ceadc-d2fd30ae.jpg",
        "000ffbff-3d93bcef-da8b17cd-fbcede53-51728df9.jpg",
        "0016e39b-d0cad5f2-eecb7ae8-4db8b8f2-0b366f1a.jpg",
        "00469c3d-4ebf8374-055428f7-d798daca-3e37d354.jpg",
        "0013ac79-5eea664c-7ef52c71-7e5a25f3-013715fc.jpg"
    ]

    # Build the examples list (image path only, matching the single input)
    examples = [[os.path.join("examples", f)] for f in EXAMPLE_FILENAMES]

    # Interface components
    input_image = gr.Image(type="pil", label="Upload Chest X-ray Image")
    output_en = gr.Textbox(label="Generated Radiology Report (English)", lines=5)
    output_hi = gr.Textbox(label="Generated Radiology Report (Hindi/हिन्दी)", lines=5)

    # Gradio app setup
    app = gr.Interface(
        fn=inference_wrapper,
        inputs=input_image,
        outputs=[output_en, output_hi],
        title="🔬 Cascading BiomedCLIP-BioGPT & IndicTrans2 Report Generator",
        description="Upload a chest X-ray image to generate a radiology finding in English and automatically translate it to Hindi.",
        # allow_flagging="never",
        examples=examples,
        cache_examples=False
        # cache_examples=True
    )

    print("\nStarting Gradio interface...")
    app.launch()  # share=True removed for typical Hugging Face Space deployment
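# Deployment note (an assumption, not verified against a specific Space): the
# imports above imply a requirements.txt along these lines -- pin versions as needed:
#
#   torch
#   torchvision
#   transformers
#   open_clip_torch
#   gradio
#   huggingface_hub
#   IndicTransToolkit
#   Pillow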