"""HTTP API for removing objects from photos via mask-guided inpainting.

Clients either upload an image and a mask separately and reference them by
id (``/inpaint``, ``/inpaint-url``), or send both files in a single request
(``/inpaint-multipart``).  Results are written to ``OUTPUT_DIR`` as PNG files
and served back through ``/download/{filename}`` and ``/result/{filename}``.
"""

import hmac
import logging
import os
import shutil
import uuid
from datetime import datetime
from typing import Callable, Dict, List, Optional, TypeVar

import cv2
import numpy as np
from fastapi import Depends, FastAPI, File, Header, HTTPException, Request, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("api")

# Kept after the logging setup to preserve the original import order.
from src.core import process_inpaint  # noqa: E402

T = TypeVar("T")

# Directories (use writable space on HF Spaces)
BASE_DIR = os.environ.get("DATA_DIR", "/data")
if not os.path.isdir(BASE_DIR):
    # Fallback to /tmp if /data not available
    BASE_DIR = "/tmp"

UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Optional Bearer token: set env API_TOKEN to require auth; if not set, endpoints are open
ENV_TOKEN = os.environ.get("API_TOKEN")

app = FastAPI(title="Photo Object Removal API", version="1.0.0")

# In-memory stores (lost on restart): uploaded-file metadata keyed by id, and an activity log
file_store: Dict[str, Dict[str, str]] = {}
logs: List[Dict[str, str]] = []


def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """FastAPI dependency enforcing the optional ``API_TOKEN`` bearer token.

    Does nothing when ``API_TOKEN`` is not configured.

    Raises:
        HTTPException: 401 if the header is missing or is not a Bearer
            credential; 403 if the presented token does not match.
    """
    if not ENV_TOKEN:
        return
    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")
    token = authorization.split(" ", 1)[1]
    # Constant-time comparison so response timing does not leak token prefixes.
    # Compared as bytes because compare_digest rejects non-ASCII str arguments.
    if not hmac.compare_digest(token.encode("utf-8"), ENV_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")


class InpaintRequest(BaseModel):
    """JSON body for the id-based inpaint endpoints."""

    image_id: str
    mask_id: str
    invert_mask: bool = True  # True => selected/painted area is removed
    passthrough: bool = False  # If True, return the original image unchanged


@app.get("/")
def root() -> Dict[str, object]:
    """Return a short self-description of the API and its endpoints."""
    return {
        "name": "Photo Object Removal API",
        "status": "ok",
        "endpoints": {
            "GET /health": "health check",
            "POST /upload-image": "form-data: image=file",
            "POST /upload-mask": "form-data: mask=file",
            "POST /inpaint": "JSON: {image_id, mask_id}",
            "POST /inpaint-url": "JSON: {image_id, mask_id}; also returns a download URL",
            "POST /inpaint-multipart": "form-data: image=file, mask=file",
            "GET /download/{filename}": "download result image",
            "GET /result/{filename}": "view result image in browser",
            "GET /logs": "recent uploads/results",
        },
        "auth": "set API_TOKEN env var to require Authorization: Bearer (except /health)",
    }


@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe; always reports healthy."""
    return {"status": "healthy"}


def _store_upload(upload: UploadFile, kind: str) -> Dict[str, str]:
    """Persist an uploaded file under a fresh UUID and register it.

    Args:
        upload: The incoming multipart file.
        kind: Either ``"image"`` or ``"mask"``; recorded as the entry type.

    Returns:
        ``{"id": <uuid>, "filename": <client-supplied name>}``.
    """
    original_name = upload.filename or ""  # filename may be absent on the part
    ext = os.path.splitext(original_name)[1]
    # The extension comes from the client; only keep plain alphanumeric ones
    # so nothing unexpected ends up in the stored path.
    if not ext[1:].isalnum():
        ext = ".png"

    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(upload.file, f)

    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": kind,
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": kind, "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}


@app.post("/upload-image")
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded source image and return its id."""
    return _store_upload(image, "image")


@app.post("/upload-mask")
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an uploaded mask and return its id."""
    return _store_upload(mask, "mask")


def _read_image(source, label: str, loader: Callable[[Image.Image], T]) -> T:
    """Open ``source`` with Pillow, apply ``loader`` and release the handle.

    Args:
        source: A filesystem path or a binary file object.
        label: Name used in error messages (``"image"`` or ``"mask"``).
        loader: Callable that turns the opened image into the wanted value;
            it must finish reading pixel data before returning.

    Raises:
        HTTPException: 404 if a path does not exist; 400 if the data cannot
            be decoded as an image.
    """
    try:
        # The context manager closes the file Pillow opened for a path; a
        # caller-supplied file object is left for the caller to close.
        with Image.open(source) as img:
            return loader(img)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=f"{label} file not found") from exc
    except (OSError, ValueError) as exc:
        # Pillow's "cannot identify image file" error is an OSError subclass.
        raise HTTPException(status_code=400, detail=f"{label} is not a valid image") from exc


def _to_rgba(img: Image.Image) -> Image.Image:
    """Return a fully loaded RGBA copy of ``img``."""
    return img.convert("RGBA")


def _load_rgba_image(path: str) -> Image.Image:
    """Load the image at ``path`` and convert it to RGBA."""
    return _read_image(path, "image", _to_rgba)


def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """
    Convert mask image to RGBA format.

    Standard convention: white (255) = area to remove, black (0) = area to keep
    Returns RGBA where alpha=0 means "to remove", alpha=255 means "keep"
    (This will be inverted in process_inpaint if invert_mask=True)
    """
    if img.mode != "RGBA":
        # For RGB/Grayscale masks: white (value>128) = remove, black (value<=128) = keep
        gray = img.convert("L")
        arr = np.array(gray)
        # White pixels (>128) should have alpha=0 (to remove after inversion)
        # Black pixels (<=128) should have alpha=255 (to keep after inversion)
        alpha = np.where(arr > 128, 0, 255).astype(np.uint8)
        rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
        rgba[:, :, 3] = alpha
        log.info(
            "Loaded %s mask: %d pixels marked for removal (alpha=0)",
            img.mode,
            int((alpha == 0).sum()),
        )
        return rgba

    # For RGBA: check if alpha channel is meaningful
    arr = np.array(img)
    alpha = arr[:, :, 3]
    # If alpha is mostly opaque everywhere (mean > 200), treat RGB channels as mask values
    if alpha.mean() > 200:
        # Use RGB to determine mask: white in RGB = remove
        gray = cv2.cvtColor(arr[:, :, :3], cv2.COLOR_RGB2GRAY)
        alpha = np.where(gray > 128, 0, 255).astype(np.uint8)
        rgba = arr.copy()
        rgba[:, :, 3] = alpha
        log.info(
            "Loaded RGBA mask (RGB-based): %d pixels marked for removal (alpha=0)",
            int((alpha == 0).sum()),
        )
        return rgba

    # Alpha channel already encodes the mask
    log.info(
        "Loaded RGBA mask (alpha-based): %d pixels marked for removal (alpha<128)",
        int((alpha < 128).sum()),
    )
    return arr


def _stored_path(file_id: str, kind: str) -> str:
    """Return the on-disk path of a registered upload of the given kind.

    Raises:
        HTTPException: 404 if the id is unknown or registered as another type.
    """
    entry = file_store.get(file_id)
    if entry is None or entry["type"] != kind:
        raise HTTPException(status_code=404, detail=f"{kind}_id not found")
    return entry["path"]


def _inpaint_stored(req: InpaintRequest) -> np.ndarray:
    """Run the request against previously uploaded files and return the result array."""
    image_path = _stored_path(req.image_id, "image")
    mask_path = _stored_path(req.mask_id, "mask")

    img_rgba = _load_rgba_image(image_path)
    # The mask is opened in its native mode (may be RGB/gray/RGBA) because the
    # converter branches on it.
    mask_rgba = _read_image(mask_path, "mask", _load_rgba_mask_from_image)

    if req.passthrough:
        return np.array(img_rgba.convert("RGB"))
    return process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)


def _save_result(result: np.ndarray) -> str:
    """Write ``result`` to ``OUTPUT_DIR`` as a uniquely named PNG; return the file name."""
    result_name = f"output_{uuid.uuid4().hex}.png"
    Image.fromarray(result).save(os.path.join(OUTPUT_DIR, result_name))
    return result_name


def _record_result(result_name: str, url: Optional[str] = None) -> Dict[str, str]:
    """Append a result entry to ``logs`` and build the endpoint response.

    The ``url`` key is included in both only when a URL is available.
    """
    resp: Dict[str, str] = {"result": result_name}
    if url:
        resp["url"] = url
    logs.append({**resp, "timestamp": datetime.utcnow().isoformat()})
    return resp


def _best_effort_url(request: Optional[Request], result_name: str) -> Optional[str]:
    """Build the download URL for a result, or ``None`` if that is not possible."""
    if request is None:
        return None
    try:
        return str(request.url_for("download_file", filename=result_name))
    except Exception:
        # Deliberately best-effort: the result is still usable by name, but
        # the failure is logged instead of being silently dropped.
        log.warning("could not build download URL for %s", result_name, exc_info=True)
        return None


@app.post("/inpaint")
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Inpaint a stored image with a stored mask and return the result file name."""
    result = _inpaint_stored(req)
    return _record_result(_save_result(result))


@app.post("/inpaint-url")
def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Same as /inpaint but returns a JSON with a public download URL instead of image bytes."""
    result = _inpaint_stored(req)
    result_name = _save_result(result)
    url = str(request.url_for("download_file", filename=result_name))
    return _record_result(result_name, url)


def _mask_from_painted(original: Image.Image, painted: Image.Image) -> np.ndarray:
    """Derive an RGBA mask from a painted-over copy of the original image.

    The selection is the set of pixels that differ between the two images
    (Otsu-thresholded).  If fewer than 50 pixels differ, very dark pixels of
    the painted image are used instead.  Selected pixels get alpha=0, all
    others alpha=255; the colour channels are zero.

    Raises:
        HTTPException: 400 if the two images do not have the same dimensions.
    """
    if original.size != painted.size:
        # cv2.absdiff cannot compare arrays of different shapes.
        raise HTTPException(
            status_code=400,
            detail="painted mask must have the same dimensions as the image",
        )

    img_rgb = cv2.cvtColor(np.array(original), cv2.COLOR_RGBA2RGB)
    m_rgb = cv2.cvtColor(np.array(painted), cv2.COLOR_RGBA2RGB)
    diff = cv2.absdiff(img_rgb, m_rgb)
    gray = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)

    # Otsu threshold for robustness; fallback threshold if Otsu fails
    try:
        _, binmask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    except cv2.error:
        log.warning("Otsu threshold failed; using fixed threshold", exc_info=True)
        _, binmask = cv2.threshold(gray, 40, 255, cv2.THRESH_BINARY)
    nonzero = int((binmask > 0).sum())
    log.info("painted-mask via diff, pixels: %d", nonzero)

    # If nothing detected (user uploaded plain original as mask), detect dark strokes directly
    if nonzero < 50:
        gray_painted = cv2.cvtColor(m_rgb, cv2.COLOR_RGB2GRAY)
        # pick very dark paint
        _, binmask = cv2.threshold(gray_painted, 60, 255, cv2.THRESH_BINARY_INV)
        nonzero = int((binmask > 0).sum())
        log.info("painted-mask via dark-thresh, pixels: %d", nonzero)

    # Build RGBA mask where selected area has alpha=0
    mask_rgba = np.zeros((binmask.shape[0], binmask.shape[1], 4), dtype=np.uint8)
    mask_rgba[:, :, 3] = np.where(binmask > 0, 0, 255).astype(np.uint8)
    return mask_rgba


@app.post("/inpaint-multipart")
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    request: Request = None,
    invert_mask: bool = True,
    mask_is_painted: bool = False,  # if True, mask file is the painted-on image (e.g., black strokes on original)
    passthrough: bool = False,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Inpaint an image and mask sent together in one multipart request.

    Returns the result file name and, when it can be built, a download URL.

    Raises:
        HTTPException: 400 if either upload is not a decodable image, or if
            ``mask_is_painted`` is set and the two images differ in size.
    """
    # Load in-memory
    img = _read_image(image.file, "image", _to_rgba)
    m = _read_image(mask.file, "mask", _to_rgba)

    if passthrough:
        # Just echo the input image, ignore mask
        result = np.array(img.convert("RGB"))
    else:
        if mask_is_painted:
            # Derive mask by differencing painted image vs original
            mask_rgba = _mask_from_painted(img, m)
        else:
            mask_rgba = _load_rgba_mask_from_image(m)
        result = process_inpaint(np.array(img), mask_rgba, invert_mask=invert_mask)

    result_name = _save_result(result)
    return _record_result(result_name, _best_effort_url(request, result_name))


def _output_path(filename: str) -> str:
    """Resolve a result file name to an existing file inside ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 if the name contains path components or no such
            file exists.
    """
    # Only bare file names are accepted, so a crafted name cannot point
    # outside OUTPUT_DIR.
    if filename != os.path.basename(filename) or "\\" in filename:
        raise HTTPException(status_code=404, detail="file not found")
    path = os.path.join(OUTPUT_DIR, filename)
    if not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return path


@app.get("/download/{filename}")
def download_file(filename: str):
    """Serve a result file from ``OUTPUT_DIR``."""
    return FileResponse(_output_path(filename))


@app.get("/result/{filename}")
def view_result(filename: str):
    """View result image directly in browser (same as download but with proper content-type for viewing)"""
    return FileResponse(_output_path(filename), media_type="image/png")


@app.get("/logs")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Return the in-memory activity log of uploads and results."""
    return JSONResponse(content=logs)