Spaces:
Running
on
T4
Running
on
T4
import hmac
import os
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header, Request
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
from pydantic import BaseModel
from pymongo import MongoClient

from src.core import process_inpaint
# Optional MongoDB request logging. It is enabled only when the MONGO_URI
# environment variable is set (on HF Spaces: a Space env variable/secret).
MONGO_URI = os.environ.get("MONGO_URI")
mongo_client = None
mongo_db = None
mongo_collection = None
if MONGO_URI:
    try:
        # MongoClient connects lazily, so this only fails on a malformed URI or
        # bad options. An unreachable server surfaces later, on insert_one. The
        # short server-selection timeout keeps such an outage from stalling a
        # request for PyMongo's (much longer, documented 30 s) default.
        mongo_client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
        mongo_db = mongo_client["object-remove-logs"]
        mongo_collection = mongo_db["api_logs"]
    except Exception as e:
        # Leave Mongo logging fully disabled rather than half-initialised.
        mongo_client = mongo_db = mongo_collection = None
        print("Mongo connection failed:", e)
# Storage directories. Prefer DATA_DIR (default /data, the writable space on
# HF Spaces). Fall back to /tmp when that directory is missing OR not
# writable, so a read-only data mount does not crash startup in the makedirs
# calls below.
BASE_DIR = os.environ.get("DATA_DIR", "/data")
if not (os.path.isdir(BASE_DIR) and os.access(BASE_DIR, os.W_OK | os.X_OK)):
    BASE_DIR = "/tmp"
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
app = FastAPI(title="Photo Object Removal API", version="1.0.0")

# Optional shared secret. When API_TOKEN is set to a non-empty value,
# endpoints guarded by bearer_auth require "Authorization: Bearer <token>".
# When it is unset, those endpoints are open.
ENV_TOKEN = os.environ.get("API_TOKEN")

# Process-local, in-memory state. It is lost on restart, and nothing in this
# module prunes it.
# file_store: upload id -> {"type", "filename", "stored_name", "path", "timestamp"}
file_store: Dict[str, Dict[str, str]] = dict()
# logs: append-only list of upload/result events, served by get_logs
logs: List[Dict[str, str]] = list()
def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """FastAPI dependency enforcing an optional static bearer token.

    If ``ENV_TOKEN`` is unset or empty, every request is allowed. Otherwise the
    ``Authorization`` header must be ``Bearer <token>``. The scheme is matched
    case-insensitively, and ``<token>`` must equal ``ENV_TOKEN`` exactly.

    Raises:
        HTTPException: 401 when the header is missing or is not a bearer
            header; 403 when the supplied token does not match.
    """
    if not ENV_TOKEN:
        return
    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")
    token = authorization.split(" ", 1)[1]
    # Constant-time comparison, so response timing does not reveal how much of
    # the token prefix was correct. Compare bytes because compare_digest
    # rejects non-ASCII str arguments, and header values may contain them.
    if not hmac.compare_digest(token.encode("utf-8"), ENV_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
class InpaintRequest(BaseModel):
    """JSON body for the id-based inpaint endpoints.

    Both fields are ids previously returned by the upload endpoints and are
    looked up in ``file_store``.
    """

    # id returned by upload_image; must map to a file_store entry of type "image"
    image_id: str
    # id returned by upload_mask; must map to a file_store entry of type "mask"
    mask_id: str
def root() -> Dict[str, object]:
    """Describe the service: its name, status, endpoint catalogue and auth mode.

    The endpoint map documents every handler in this module. The auth note
    lists the handlers that take no ``bearer_auth`` dependency and therefore
    stay reachable without a token.
    """
    return {
        "name": "Photo Object Removal API",
        "status": "ok",
        "endpoints": {
            "GET /health": "health check",
            "POST /upload-image": "form-data: image=file",
            "POST /upload-mask": "form-data: mask=file",
            "POST /inpaint": "JSON: {image_id, mask_id}",
            "POST /inpaint-url": "JSON: {image_id, mask_id} -> {result, url}",
            "POST /inpaint-multipart": "form-data: image=file, mask=file",
            "GET /download/{filename}": "download result image",
            "GET /result/{filename}": "view result image in browser",
            "GET /logs": "recent uploads/results",
        },
        "auth": (
            "set API_TOKEN env var to require Authorization: Bearer <token> "
            "(except /, /health, /download and /result)"
        ),
    }
def health() -> Dict[str, str]:
    """Liveness probe: unconditionally reports the service as healthy."""
    state = "healthy"
    return dict(status=state)
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Save an uploaded source image and register it in ``file_store``.

    The bytes are written to ``UPLOAD_DIR/<uuid4><ext>``. ``<ext>`` comes from
    the client-supplied filename, or is ``.png`` when there is none.

    Returns:
        ``{"id": <file id>, "filename": <client filename>}``. The id is what
        ``inpaint`` expects as ``image_id``.
    """
    # UploadFile.filename is Optional; os.path.splitext(None) raises TypeError,
    # so normalise a missing name to "".
    original_name = image.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(image.file, f)
    # Take a single timestamp so the store entry and the log entry agree.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": "image",
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": "image", "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Save an uploaded mask image and register it in ``file_store``.

    The bytes are written to ``UPLOAD_DIR/<uuid4><ext>``. ``<ext>`` comes from
    the client-supplied filename, or is ``.png`` when there is none.

    Returns:
        ``{"id": <file id>, "filename": <client filename>}``. The id is what
        ``inpaint`` expects as ``mask_id``.
    """
    # UploadFile.filename is Optional; os.path.splitext(None) raises TypeError,
    # so normalise a missing name to "".
    original_name = mask.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(mask.file, f)
    # Take a single timestamp so the store entry and the log entry agree.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": "mask",
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": "mask", "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}
def _load_rgba_image(path: str) -> Image.Image:
    """Open the image at ``path``, return an RGBA copy, and close the file.

    Using the context manager releases the file handle deterministically,
    including for multi-frame formats where PIL would otherwise keep it open.
    """
    with Image.open(path) as img:
        # convert() fully loads the pixels and returns a new image, so the
        # result remains valid after the source file is closed.
        return img.convert("RGBA")
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Turn a mask image into the RGBA array layout passed to process_inpaint.

    Layout expected by process_inpaint: alpha == 0 marks pixels to remove, and
    alpha == 255 marks pixels to keep.

    * An RGBA image is returned as ``np.array(img)`` without modification.
    * Any other mode is converted to 8-bit grayscale. Every pixel that is not
      pure black (value > 0) gets alpha 0 (remove), and pure black gets alpha
      255 (keep). The RGB channels of the returned array are all zero.
    """
    if img.mode == "RGBA":
        return np.array(img)

    luminance = np.array(img.convert("L"))
    mask = np.zeros((img.height, img.width, 4), dtype=np.uint8)
    mask[:, :, 3] = np.where(luminance > 0, 0, 255).astype(np.uint8)
    return mask
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Run object removal on a previously uploaded image/mask pair.

    The result PNG is written to ``OUTPUT_DIR/output_<hex>.png``. It is
    recorded in ``logs`` and, when configured, in MongoDB.

    Returns:
        ``{"result": <result filename>}``.

    Raises:
        HTTPException: 404 if ``image_id`` is not an uploaded image or
            ``mask_id`` is not an uploaded mask.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")

    img_rgba = _load_rgba_image(image_entry["path"])
    # Close the mask file once its pixels have been copied into an array.
    with Image.open(mask_entry["path"]) as mask_img:
        mask_rgba = _load_rgba_mask_from_image(mask_img)
    result = process_inpaint(np.array(img_rgba), mask_rgba)

    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)

    timestamp = datetime.utcnow().isoformat()
    logs.append({"result": result_name, "timestamp": timestamp})
    # PyMongo Collection objects raise NotImplementedError on bool(), so test
    # against None. Logging is best-effort: a Mongo outage must not turn an
    # already-saved result into a 500.
    if mongo_collection is not None:
        try:
            mongo_collection.insert_one({
                "endpoint": "/inpaint",
                "result": result_name,
                "timestamp": timestamp,
            })
        except Exception as e:
            print("Mongo logging failed:", e)
    return {"result": result_name}
def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Run object removal like ``inpaint`` and also return an absolute result URL.

    The URL is built with ``request.url_for("download_file", ...)``, so
    clients can fetch the PNG directly. Like the other inpaint handlers, the
    run is recorded in ``logs`` and, when configured, in MongoDB.

    Returns:
        ``{"result": <result filename>, "url": <download URL>}``.

    Raises:
        HTTPException: 404 if ``image_id`` is not an uploaded image or
            ``mask_id`` is not an uploaded mask.
    """
    image_entry = file_store.get(req.image_id)
    if image_entry is None or image_entry["type"] != "image":
        raise HTTPException(status_code=404, detail="image_id not found")
    mask_entry = file_store.get(req.mask_id)
    if mask_entry is None or mask_entry["type"] != "mask":
        raise HTTPException(status_code=404, detail="mask_id not found")

    img_rgba = _load_rgba_image(image_entry["path"])
    # The mask may be RGB/grayscale/RGBA; close its file once converted.
    with Image.open(mask_entry["path"]) as mask_img:
        mask_rgba = _load_rgba_mask_from_image(mask_img)
    result = process_inpaint(np.array(img_rgba), mask_rgba)

    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)
    url = str(request.url_for("download_file", filename=result_name))

    timestamp = datetime.utcnow().isoformat()
    logs.append({"result": result_name, "url": url, "timestamp": timestamp})
    # Best-effort Mongo logging, consistent with the other inpaint handlers.
    # Compare with None because PyMongo collections reject bool().
    if mongo_collection is not None:
        try:
            mongo_collection.insert_one({
                "endpoint": request.url.path,
                "result": result_name,
                "url": url,
                "timestamp": timestamp,
            })
        except Exception as e:
            print("Mongo logging failed:", e)
    return {"result": result_name, "url": url}
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    request: Request = None,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Run object removal on an image and mask uploaded in a single request.

    Nothing is added to ``file_store``. The result PNG is written to
    ``OUTPUT_DIR``, recorded in ``logs``, and recorded in MongoDB when that is
    configured.

    Returns:
        ``{"result": <result filename>}``, plus ``"url"`` when a download URL
        could be built for the result.
    """
    # Context managers release PIL's references to the uploaded streams.
    with Image.open(image.file) as src:
        img = src.convert("RGBA")
    with Image.open(mask.file) as m:
        mask_rgba = _load_rgba_mask_from_image(m)
    result = process_inpaint(np.array(img), mask_rgba)

    result_name = f"output_{uuid.uuid4().hex}.png"
    result_path = os.path.join(OUTPUT_DIR, result_name)
    Image.fromarray(result).save(result_path)

    url: Optional[str] = None
    try:
        if request is not None:
            url = str(request.url_for("download_file", filename=result_name))
    except Exception:
        # Deliberately best-effort: the URL is optional in the response.
        url = None

    timestamp = datetime.utcnow().isoformat()
    entry: Dict[str, str] = {"result": result_name, "timestamp": timestamp}
    if url:
        entry["url"] = url
    logs.append(entry)

    # PyMongo Collection objects raise NotImplementedError on bool(), so test
    # against None. Logging is best-effort: a Mongo outage must not turn an
    # already-saved result into a 500.
    if mongo_collection is not None:
        try:
            mongo_collection.insert_one({
                "endpoint": "/inpaint-multipart",
                "result": result_name,
                "url": url,
                "timestamp": timestamp,
            })
        except Exception as e:
            print("Mongo logging failed:", e)

    resp: Dict[str, str] = {"result": result_name}
    if url:
        resp["url"] = url
    return resp
def download_file(filename: str):
    """Serve a result file stored directly inside ``OUTPUT_DIR``.

    ``filename`` is client-controlled. It is resolved (following ``..`` and
    symlinks) and must name a regular file whose parent directory is exactly
    ``OUTPUT_DIR``, so requests cannot read files elsewhere on disk.

    Raises:
        HTTPException: 404 when no such file exists inside ``OUTPUT_DIR``.
    """
    output_root = os.path.realpath(OUTPUT_DIR)
    path = os.path.realpath(os.path.join(output_root, filename))
    if os.path.dirname(path) != output_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(path)
def view_result(filename: str):
    """Serve a result image for in-browser viewing with ``image/png`` content type.

    Uses the same path confinement as downloads. ``filename`` is
    client-controlled and must resolve to a regular file directly inside
    ``OUTPUT_DIR``.

    Raises:
        HTTPException: 404 when no such file exists inside ``OUTPUT_DIR``.
    """
    output_root = os.path.realpath(OUTPUT_DIR)
    path = os.path.realpath(os.path.join(output_root, filename))
    if os.path.dirname(path) != output_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return FileResponse(path, media_type="image/png")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Return every in-memory upload/result log entry, in insertion order.

    Guarded by ``bearer_auth``.
    """
    snapshot = list(logs)
    return JSONResponse(snapshot)