"""Gradio demo that captions food images with a fine-tuned BLIP-2 model."""

import gradio as gr
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration


def load_model():
    """Load the BLIP-2 processor and model used for captioning.

    Returns:
        A ``(processor, model)`` tuple. The model is put in eval mode.
    """
    model_id = "fathindifa/food-caption-blip2"
    processor = Blip2Processor.from_pretrained(model_id)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        device_map="auto",  # Will use GPU if available, otherwise CPU
    )
    model.eval()
    return processor, model


# Loaded once at import time so every request reuses the same weights.
processor, model = load_model()


def generate_caption(image, prompt="", max_length=32, temperature=0.7):
    """Generate a caption for the input image.

    Args:
        image: Image to caption, either an array accepted by
            ``PIL.Image.fromarray`` or a ``PIL.Image.Image``. ``None`` means
            nothing was uploaded.
        prompt: Optional text passed to the processor to guide the caption.
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded caption, or a short instruction message when ``image``
        is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # Ensure image is a PIL image in RGB mode.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Process image (and optional prompt) into model inputs.
    inputs = processor(images=image, text=prompt, return_tensors="pt")

    # Move inputs to the same device as the model.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # UI sliders may deliver floats; generate expects an integer here.
            max_new_tokens=int(max_length),
            num_beams=5,
            # Temperature is only applied when sampling is enabled; without
            # this flag the slider value would be ignored.
            do_sample=True,
            temperature=float(temperature),
        )

    # Decode the first (only) sequence and drop surrounding whitespace.
    caption = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return caption.strip()


# Create Gradio interface
iface = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(label="Upload Food Image"),
        gr.Textbox(
            label="Optional Prompt",
            placeholder="Enter prompt to guide the caption (optional)",
            value="",
        ),
        gr.Slider(
            minimum=10,
            maximum=50,
            value=32,
            step=1,
            label="Maximum Caption Length",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
    title="🍜 Food Image Captioning",
    description=(
        "Upload a food image and get an automatically generated caption! "
        "This demo uses a fine-tuned BLIP2 model specifically trained for "
        "food image captioning."
    ),
    article="""
### Tips:
- Try different prompts to guide the caption generation
- Adjust temperature for more/less creative captions
- Increase max length for longer descriptions
""",
    flagging_mode="never",
    cache_examples=False,
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()