"""HTTP API for photo object removal (inpainting).

Clients upload an image and a mask (separately or in one multipart request),
the service runs ``process_inpaint`` on them, stores the result as a PNG in
``OUTPUT_DIR`` and exposes it via the download/view endpoints.
"""

import hmac
import os
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header, Request
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from PIL import Image
from pymongo import MongoClient

from src.core import process_inpaint

# Connect to MongoDB (URI must be in HF env variables)
MONGO_URI = os.environ.get("MONGO_URI")
mongo_client = None
mongo_collection = None
if MONGO_URI:
    try:
        mongo_client = MongoClient(MONGO_URI)
        mongo_db = mongo_client["object-remove-logs"]
        mongo_collection = mongo_db["api_logs"]
    except Exception as e:
        print("Mongo connection failed:", e)

# Directories (use writable space on HF Spaces)
BASE_DIR = os.environ.get("DATA_DIR", "/data")
if not os.path.isdir(BASE_DIR):
    # Fallback to /tmp if /data not available
    BASE_DIR = "/tmp"

UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Optional Bearer token: set env API_TOKEN to require auth; if not set, endpoints are open
ENV_TOKEN = os.environ.get("API_TOKEN")

app = FastAPI(title="Photo Object Removal API", version="1.0.0")

# In-memory stores
file_store: Dict[str, Dict[str, str]] = {}
logs: List[Dict[str, str]] = []


def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """Enforce ``Authorization: Bearer <token>`` when ``API_TOKEN`` is set.

    Raises:
        HTTPException: 401 if the header is missing or not a Bearer header,
            403 if the presented token does not match ``API_TOKEN``.
    """
    if not ENV_TOKEN:
        return
    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")
    token = authorization.split(" ", 1)[1]
    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Compare bytes because compare_digest rejects
    # non-ASCII str arguments.
    presented = token.encode("utf-8", "surrogateescape")
    expected = ENV_TOKEN.encode("utf-8", "surrogateescape")
    if not hmac.compare_digest(presented, expected):
        raise HTTPException(status_code=403, detail="Forbidden")


class InpaintRequest(BaseModel):
    """JSON body referencing a previously uploaded image and mask by id."""

    image_id: str
    mask_id: str


def _log_to_mongo(document: Dict[str, object]) -> None:
    """Best-effort insert of *document* into the Mongo log collection.

    PyMongo collections do not support truth-value testing, so the presence
    check must be an explicit ``is not None``. Insert failures are reported
    and swallowed: the inpaint result already exists on disk by the time
    this runs, so a logging problem should not fail the request.
    """
    if mongo_collection is None:
        return
    try:
        mongo_collection.insert_one(document)
    except Exception as e:
        print("Mongo logging failed:", e)


def _store_upload(upload: UploadFile, kind: str) -> Dict[str, str]:
    """Persist *upload* under ``UPLOAD_DIR`` and register it in ``file_store``.

    Args:
        upload: The uploaded file.
        kind: Value recorded as the entry's ``"type"`` (``"image"`` or ``"mask"``).

    Returns:
        ``{"id": <generated uuid>, "filename": <client filename>}``.
    """
    # The client may omit the filename; fall back to an empty name so that
    # splitext does not raise on None.
    original_name = upload.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(upload.file, f)
    file_store[file_id] = {
        "type": kind,
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": datetime.utcnow().isoformat(),
    }
    logs.append(
        {
            "id": file_id,
            "filename": original_name,
            "type": kind,
            "timestamp": datetime.utcnow().isoformat(),
        }
    )
    return {"id": file_id, "filename": original_name}


def _load_rgba_image(path: str) -> Image.Image:
    """Open the image at *path* and return an RGBA copy, closing the file."""
    with Image.open(path) as img:
        return img.convert("RGBA")


def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Convert a mask image into the RGBA array passed to ``process_inpaint``.

    An RGBA input is returned as-is (as an array). Any other mode is
    converted to grayscale and every non-zero pixel becomes alpha 0, all
    other pixels alpha 255; the RGB channels are left at zero.
    """
    # Expected by process_inpaint: RGBA where alpha=0 for drawn (to remove), 255 elsewhere
    if img.mode != "RGBA":
        # If no alpha, treat non-black/white>0 as masked areas
        gray = img.convert("L")
        arr = np.array(gray)
        alpha = np.where(arr > 0, 0, 255).astype(np.uint8)
        rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
        rgba[:, :, 3] = alpha
        return rgba
    return np.array(img)


def _run_inpaint(img_rgba: Image.Image, mask_rgba: np.ndarray) -> str:
    """Run ``process_inpaint``, save the result as a PNG and return its filename."""
    result = process_inpaint(np.array(img_rgba), mask_rgba)
    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)
    return result_name


def _inpaint_stored(req: InpaintRequest) -> str:
    """Inpaint using the stored image/mask named by *req*; return the result filename.

    Raises:
        HTTPException: 404 if either id is unknown or refers to an entry of
            the wrong type.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")

    img_rgba = _load_rgba_image(image_entry["path"])
    # Mask may be RGB/gray/RGBA; it is fully read into an array before the
    # file handle is released.
    with Image.open(mask_entry["path"]) as mask_img:
        mask_rgba = _load_rgba_mask_from_image(mask_img)
    return _run_inpaint(img_rgba, mask_rgba)


def _resolve_output_path(filename: str) -> str:
    """Return the path of *filename* inside ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 if the name is not a bare filename (contains a
            directory component) or no such file exists.
    """
    # Refuse anything that could escape OUTPUT_DIR.
    safe_name = os.path.basename(filename)
    path = os.path.join(OUTPUT_DIR, safe_name)
    if safe_name != filename or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return path


@app.get("/")
def root() -> Dict[str, object]:
    """Describe the service and list its endpoints."""
    return {
        "name": "Photo Object Removal API",
        "status": "ok",
        "endpoints": {
            "GET /health": "health check",
            "POST /upload-image": "form-data: image=file",
            "POST /upload-mask": "form-data: mask=file",
            "POST /inpaint": "JSON: {image_id, mask_id}",
            "POST /inpaint-multipart": "form-data: image=file, mask=file",
            "GET /download/{filename}": "download result image",
            "GET /result/{filename}": "view result image in browser",
            "GET /logs": "recent uploads/results",
        },
        "auth": "set API_TOKEN env var to require Authorization: Bearer (except /health)",
    }


@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe."""
    return {"status": "healthy"}


@app.post("/upload-image")
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded source image and return its generated id."""
    return _store_upload(image, "image")


@app.post("/upload-mask")
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded mask and return its generated id."""
    return _store_upload(mask, "mask")


@app.post("/inpaint")
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Inpaint a previously uploaded image/mask pair and return the result filename."""
    result_name = _inpaint_stored(req)

    logs.append({"result": result_name, "timestamp": datetime.utcnow().isoformat()})

    # Mongo logging
    _log_to_mongo(
        {
            "endpoint": "/inpaint",
            "result": result_name,
            "timestamp": datetime.utcnow().isoformat(),
        }
    )

    return {"result": result_name}


@app.post("/inpaint-url")
def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Same as /inpaint but returns a JSON with a public download URL instead of image bytes."""
    result_name = _inpaint_stored(req)

    url = str(request.url_for("download_file", filename=result_name))
    logs.append({"result": result_name, "url": url, "timestamp": datetime.utcnow().isoformat()})
    return {"result": result_name, "url": url}


@app.post("/inpaint-multipart")
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    request: Request = None,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Inpaint an image/mask pair sent in a single multipart request.

    Returns the result filename and, when it can be built, its download URL.
    """
    with Image.open(image.file) as src:
        img = src.convert("RGBA")
    with Image.open(mask.file) as m:
        mask_rgba = _load_rgba_mask_from_image(m)

    result_name = _run_inpaint(img, mask_rgba)

    # Building the URL is best-effort: the result is still returned by name
    # if the URL cannot be generated.
    url: Optional[str] = None
    try:
        if request is not None:
            url = str(request.url_for("download_file", filename=result_name))
    except Exception:
        url = None

    entry: Dict[str, str] = {"result": result_name, "timestamp": datetime.utcnow().isoformat()}
    if url:
        entry["url"] = url
    logs.append(entry)

    # Mongo logging
    _log_to_mongo(
        {
            "endpoint": "/inpaint-multipart",
            "result": result_name,
            "url": url,
            "timestamp": datetime.utcnow().isoformat(),
        }
    )

    resp: Dict[str, str] = {"result": result_name}
    if url:
        resp["url"] = url
    return resp


@app.get("/download/{filename}")
def download_file(filename: str):
    """Return a stored result file; 404 if it does not exist."""
    return FileResponse(_resolve_output_path(filename))


@app.get("/result/{filename}")
def view_result(filename: str):
    """View result image directly in browser (same as download but with proper content-type for viewing)"""
    return FileResponse(_resolve_output_path(filename), media_type="image/png")


@app.get("/logs")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Return the in-memory log of uploads and results."""
    return JSONResponse(content=logs)