"""FastAPI service that stores uploaded images and edits them from text prompts.

Uploaded images are written to ``uploads/`` and edit results to ``results/``
(both relative to the working directory); metadata for both lives in the
in-process ``tasks`` dict. When gradio/transformers_gradio cannot be imported,
or no model callable could be obtained, ``/edit`` returns an unmodified copy
of the input image ("mock mode").
"""

import os
import uuid
import base64
from io import BytesIO
from typing import Optional
from datetime import datetime

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from PIL import Image

# Try to import Qwen Image Edit model
try:
    import transformers_gradio
    import gradio as gr
    GRADIO_AVAILABLE = True
except ImportError:
    GRADIO_AVAILABLE = False
    print("Warning: gradio/transformers_gradio not available. Using mock mode.")

app = FastAPI(
    title="Nano Banana Image Edit API",
    description="API for Qwen Image Edit model - Upload images and edit them with prompts",
    version="1.0.0"
)

# CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended and restrict allow_origins before
# exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory storage for tasks (use Redis or database in production).
# Holds two kinds of entries, distinguished by their "type" key:
#   "image"     -- keyed by image_id, written by /upload
#   "edit_task" -- keyed by task_id, written by /edit
tasks = {}

# Model initialization - using Gradio interface
gradio_demo = None
gradio_fn = None


def load_model():
    """Load the Qwen Image Edit model using Gradio.

    Sets the module-level ``gradio_demo`` and, when one can be found,
    ``gradio_fn`` (the callable used by ``/edit``).

    Returns:
        True only when a callable edit function is available afterwards;
        False when gradio is unavailable, loading raised, or the loaded demo
        exposes no callable.
    """
    global gradio_demo, gradio_fn

    if not GRADIO_AVAILABLE:
        return False

    try:
        print("Loading Qwen/Qwen-Image-Edit model via Gradio...")
        gradio_demo = gr.load(name="Qwen/Qwen-Image-Edit", src=transformers_gradio.registry)

        # Get the main function from the demo.
        # The function signature depends on the model, typically (image, prompt) -> image
        if hasattr(gradio_demo, 'fn'):
            gradio_fn = gradio_demo.fn
        elif hasattr(gradio_demo, 'blocks') and gradio_demo.blocks:
            # Try to find the function from blocks
            for block in gradio_demo.blocks.values():
                if hasattr(block, 'fn') and callable(block.fn):
                    gradio_fn = block.fn
                    break
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    if gradio_fn is None:
        # Without a callable, /edit falls back to mock mode; do not claim success.
        print("Warning: model loaded but no callable edit function was found. Using mock mode.")
        return False

    print("Model loaded successfully")
    return True


# Response models
class UploadResponse(BaseModel):
    """Body returned by POST /upload."""

    image_id: str
    message: str
    timestamp: str


class EditRequest(BaseModel):
    """JSON shape of an edit request (the /edit endpoint itself reads form fields)."""

    image_id: str
    prompt: str


class EditResponse(BaseModel):
    """Body returned by POST /edit."""

    task_id: str
    status: str
    message: str
    timestamp: str


class ResultResponse(BaseModel):
    """Body returned by GET /result/{task_id}."""

    task_id: str
    status: str
    result_image_id: Optional[str] = None
    result_image_url: Optional[str] = None
    error: Optional[str] = None
    timestamp: str


class ErrorResponse(BaseModel):
    """Generic error body."""

    error: str
    detail: Optional[str] = None


def _timestamp() -> str:
    """Return the current local time as an ISO-8601 string."""
    return datetime.now().isoformat()


def _find_result_task(result_image_id: str):
    """Return the edit task that produced ``result_image_id``, or None."""
    return next(
        (
            entry
            for entry in tasks.values()
            if entry.get("type") == "edit_task"
            and entry.get("result_image_id") == result_image_id
        ),
        None,
    )


def _require_result_path(result_image_id: str) -> str:
    """Return the on-disk path of a result image, raising HTTP 404 if it is unknown or missing."""
    task = _find_result_task(result_image_id)
    if not task or "result_path" not in task:
        raise HTTPException(status_code=404, detail="Result image not found")

    result_path = task["result_path"]
    if not os.path.exists(result_path):
        raise HTTPException(status_code=404, detail="Image file not found")
    return result_path


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    if GRADIO_AVAILABLE:
        load_model()


@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "message": "Nano Banana Image Edit API",
        "version": "1.0.0",
        "endpoints": {
            "upload": "/upload",
            "edit": "/edit",
            "result": "/result/{task_id}",
            "health": "/health"
        }
    }


@app.get("/health")
async def health():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": gradio_fn is not None if GRADIO_AVAILABLE else False,
        "model_available": GRADIO_AVAILABLE
    }


@app.post("/upload", response_model=UploadResponse)
async def upload_image(file: UploadFile = File(...)):
    """
    Upload an image file

    Returns:
        image_id: Unique identifier for the uploaded image

    Raises:
        HTTPException 400: the content type is not image/* or the bytes are
            not a valid image.
        HTTPException 500: reading the upload or storing the image failed.
    """
    # Validate file type
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(status_code=400, detail="File must be an image")

    try:
        # Read image
        image_data = await file.read()

        # Validate image. Undecodable data is the client's fault, so report
        # it as 400 and do not let it fall through to the generic 500 below.
        try:
            Image.open(BytesIO(image_data)).verify()
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}") from e

        # verify() leaves the image object unusable, so open a fresh one.
        image = Image.open(BytesIO(image_data))

        # Generate unique ID
        image_id = str(uuid.uuid4())

        # Store image (in production, use proper storage like S3, Azure Blob, etc.)
        os.makedirs("uploads", exist_ok=True)
        image_path = f"uploads/{image_id}.{image.format.lower()}"
        image.save(image_path)

        # Store metadata
        tasks[image_id] = {
            "type": "image",
            "path": image_path,
            "format": image.format,
            "size": image.size,
            "uploaded_at": _timestamp()
        }

        return UploadResponse(
            image_id=image_id,
            message="Image uploaded successfully",
            timestamp=_timestamp()
        )
    except HTTPException:
        # Already carries the intended status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e


@app.post("/edit", response_model=EditResponse)
async def edit_image(
    image_id: str = Form(...),
    prompt: str = Form(...)
):
    """
    Edit an image using a text prompt

    Parameters:
        image_id: ID of the uploaded image
        prompt: Text prompt describing the desired edit

    Returns:
        task_id: Unique identifier for the editing task

    Raises:
        HTTPException 404: no uploaded image with this ID.
        HTTPException 500: loading, editing or saving failed (a "failed" task
            is recorded before raising).
    """
    # Validate image exists
    source = tasks.get(image_id)
    if source is None or source["type"] != "image":
        raise HTTPException(status_code=404, detail="Image not found")

    # Generate task ID
    task_id = str(uuid.uuid4())

    try:
        # Load image fully into memory so the file handle is released right away.
        with Image.open(source["path"]) as image:
            image.load()

        # Process image with model
        # NOTE(review): the model call below runs synchronously inside an
        # async endpoint, so it blocks the event loop while it runs --
        # consider offloading once the model's thread-safety is confirmed.
        if GRADIO_AVAILABLE and gradio_fn is not None:
            try:
                # Call the Gradio function with image and prompt
                # The function signature should be (image, prompt) -> image
                result = gradio_fn(image, prompt)

                # Handle different return types
                if isinstance(result, tuple):
                    edited_image = result[0]  # First element is usually the image
                elif isinstance(result, dict):
                    edited_image = result.get('image', result.get('output', image))
                else:
                    edited_image = result

                # Ensure it's a PIL Image
                if not isinstance(edited_image, Image.Image):
                    edited_image = image.copy()  # Fallback to original
            except Exception as e:
                print(f"Error processing with model: {e}")
                # Fallback to mock mode
                edited_image = image.copy()
        else:
            # Mock mode - just copy the image
            edited_image = image.copy()

        # Save edited image
        os.makedirs("results", exist_ok=True)
        result_image_id = str(uuid.uuid4())
        result_path = f"results/{result_image_id}.png"
        edited_image.save(result_path)

        # Store task
        tasks[task_id] = {
            "type": "edit_task",
            "image_id": image_id,
            "prompt": prompt,
            "result_image_id": result_image_id,
            "result_path": result_path,
            "status": "completed",
            "created_at": _timestamp()
        }

        return EditResponse(
            task_id=task_id,
            status="completed",
            message="Image edited successfully",
            timestamp=_timestamp()
        )
    except Exception as e:
        # Store failed task
        tasks[task_id] = {
            "type": "edit_task",
            "image_id": image_id,
            "prompt": prompt,
            "status": "failed",
            "error": str(e),
            "created_at": _timestamp()
        }
        raise HTTPException(status_code=500, detail=f"Error editing image: {str(e)}") from e


@app.get("/result/{task_id}", response_model=ResultResponse)
async def get_result(task_id: str):
    """
    Get the result of an image editing task

    Parameters:
        task_id: ID of the editing task

    Returns:
        Result information including image URL

    Raises:
        HTTPException 404: no edit task with this ID.
    """
    task = tasks.get(task_id)
    if task is None or task["type"] != "edit_task":
        raise HTTPException(status_code=404, detail="Task not found")

    status = task["status"]

    if status == "failed":
        return ResultResponse(
            task_id=task_id,
            status="failed",
            error=task.get("error", "Unknown error"),
            timestamp=task["created_at"]
        )

    if status == "completed":
        result_image_id = task.get("result_image_id")
        return ResultResponse(
            task_id=task_id,
            status="completed",
            result_image_id=result_image_id,
            result_image_url=f"/result/image/{result_image_id}",
            timestamp=task["created_at"]
        )

    # Any other status is reported as still in progress.
    return ResultResponse(
        task_id=task_id,
        status="processing",
        timestamp=task["created_at"]
    )


@app.get("/result/image/{result_image_id}")
async def get_result_image(result_image_id: str):
    """
    Get the edited image file

    Parameters:
        result_image_id: ID of the result image

    Returns:
        Image file

    Raises:
        HTTPException 404: the result is unknown or its file is missing.
    """
    result_path = _require_result_path(result_image_id)

    return FileResponse(
        result_path,
        media_type="image/png",
        filename=f"edited_{result_image_id}.png"
    )


@app.get("/result/image/{result_image_id}/base64")
async def get_result_image_base64(result_image_id: str):
    """
    Get the edited image as base64 encoded string

    Parameters:
        result_image_id: ID of the result image

    Returns:
        JSON with base64 encoded image

    Raises:
        HTTPException 404: the result is unknown or its file is missing.
    """
    result_path = _require_result_path(result_image_id)

    # Read and encode image
    with open(result_path, "rb") as f:
        image_data = f.read()

    base64_data = base64.b64encode(image_data).decode("utf-8")

    return {
        "result_image_id": result_image_id,
        "image_base64": base64_data,
        "format": "png"
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)