# object_remover / api/main.py
# Author: LogicGoInfotechSpaces
# Last commit: 68580bd "Update api/main.py" (verified)
# Size: 9.25 kB
import os
import secrets
import shutil
import uuid
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header, Request
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
from pydantic import BaseModel
from pymongo import MongoClient

from src.core import process_inpaint
def _connect_mongo(uri):
    """Create the optional MongoDB client used for request logging.

    Returns ``(client, database, collection)``. All three are None when *uri*
    is empty/unset or when creating the client raises; in the latter case the
    error is printed and the API keeps running without Mongo logging.
    """
    if not uri:
        return None, None, None
    try:
        client = MongoClient(uri)
        database = client["object-remove-logs"]
        return client, database, database["api_logs"]
    except Exception as err:
        # NOTE(review): pymongo's MongoClient appears to connect lazily, so
        # this likely only catches malformed URIs/options; an unreachable
        # server would surface on the first insert instead — confirm.
        print("Mongo connection failed:", err)
        return None, None, None


# Mongo logging is optional: it is enabled only when MONGO_URI is set
# (expected to come from the HF Space's environment variables).
MONGO_URI = os.environ.get("MONGO_URI")
mongo_client, mongo_db, mongo_collection = _connect_mongo(MONGO_URI)
# Directories (use writable space on HF Spaces).
# DATA_DIR defaults to /data. Fall back to /tmp when that directory is missing
# OR not writable by this process: a present-but-read-only /data would
# otherwise make os.makedirs below raise at import time and stop the app.
BASE_DIR = os.environ.get("DATA_DIR", "/data")
if not (os.path.isdir(BASE_DIR) and os.access(BASE_DIR, os.W_OK | os.X_OK)):
    BASE_DIR = "/tmp"
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")  # raw uploaded images/masks
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")  # generated output_*.png results
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
app = FastAPI(title="Photo Object Removal API", version="1.0.0")

# Optional bearer token read from API_TOKEN: when set, every endpoint that
# depends on bearer_auth requires it; when unset, those endpoints are open.
ENV_TOKEN = os.environ.get("API_TOKEN")

# Process-local stores: their contents are lost on restart, and nothing in
# this module removes entries from them.
# file_store: upload id -> {type, filename, stored_name, path, timestamp}
file_store: Dict[str, Dict[str, str]] = dict()
# logs: upload/result events in insertion order, returned by GET /logs
logs: List[Dict[str, str]] = list()
def bearer_auth(authorization: Optional[str] = Header(default=None)) -> None:
    """FastAPI dependency enforcing ``Authorization: Bearer <token>``.

    Does nothing when the API_TOKEN env var is unset or empty.

    Raises:
        HTTPException: 401 if the header is missing or is not a Bearer header;
            403 if the supplied token does not equal API_TOKEN.
    """
    if not ENV_TOKEN:
        return
    if authorization is None or not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")
    token = authorization.split(" ", 1)[1]
    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Compare bytes: compare_digest rejects non-ASCII str,
    # and the header value is client-controlled.
    if not secrets.compare_digest(token.encode("utf-8"), ENV_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Forbidden")
class InpaintRequest(BaseModel):
    """JSON body of POST /inpaint and POST /inpaint-url."""

    image_id: str  # id returned by POST /upload-image
    mask_id: str  # id returned by POST /upload-mask
@app.get("/")
def root() -> Dict[str, object]:
    """Describe the service: its name, status, endpoints and auth rules.

    The endpoint list and auth note mirror the routes defined in this module;
    keep them in sync when adding routes.
    """
    return {
        "name": "Photo Object Removal API",
        "status": "ok",
        "endpoints": {
            "GET /health": "health check",
            "POST /upload-image": "form-data: image=file",
            "POST /upload-mask": "form-data: mask=file",
            "POST /inpaint": "JSON: {image_id, mask_id}",
            "POST /inpaint-url": "JSON: {image_id, mask_id} -> result name + download URL",
            "POST /inpaint-multipart": "form-data: image=file, mask=file",
            "GET /download/{filename}": "download result image",
            "GET /result/{filename}": "view result image in browser",
            "GET /logs": "uploads/results logged since startup (in-memory)",
        },
        # /, /health, /download and /result do not depend on bearer_auth.
        "auth": (
            "set API_TOKEN env var to require Authorization: Bearer <token> "
            "(except /, /health, /download and /result)"
        ),
    }
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe; unauthenticated and always reports healthy."""
    payload = dict(status="healthy")
    return payload
def _save_upload(upload: UploadFile, kind: str) -> Dict[str, str]:
    """Persist *upload* under UPLOAD_DIR and register it in ``file_store``.

    The stored file name is a fresh UUID plus the client's extension (".png"
    when there is none), so the client's name never chooses the path.

    Args:
        upload: the incoming multipart file.
        kind: "image" or "mask"; /inpaint checks this tag.

    Returns:
        ``{"id": <upload id>, "filename": <client file name>}``.
    """
    # UploadFile.filename is Optional; os.path.splitext(None) would raise.
    original_name = upload.filename or ""
    ext = os.path.splitext(original_name)[1] or ".png"
    file_id = str(uuid.uuid4())
    stored_name = f"{file_id}{ext}"
    stored_path = os.path.join(UPLOAD_DIR, stored_name)
    with open(stored_path, "wb") as f:
        shutil.copyfileobj(upload.file, f)
    # One timestamp so the store entry and the log entry agree.
    timestamp = datetime.utcnow().isoformat()
    file_store[file_id] = {
        "type": kind,
        "filename": original_name,
        "stored_name": stored_name,
        "path": stored_path,
        "timestamp": timestamp,
    }
    logs.append({"id": file_id, "filename": original_name, "type": kind, "timestamp": timestamp})
    return {"id": file_id, "filename": original_name}


@app.post("/upload-image")
def upload_image(image: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store an image for a later /inpaint or /inpaint-url call."""
    return _save_upload(image, "image")


@app.post("/upload-mask")
def upload_mask(mask: UploadFile = File(...), _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Store a mask for a later /inpaint or /inpaint-url call."""
    return _save_upload(mask, "mask")
def _load_rgba_image(path: str) -> Image.Image:
    """Open the image file at *path* and return an RGBA copy of it.

    ``convert`` produces a new, fully loaded image, so the source file can be
    closed right away. The ``with`` block guarantees the handle is released
    even for formats where PIL would otherwise keep the file open.
    """
    with Image.open(path) as img:
        return img.convert("RGBA")
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Turn a user-supplied mask image into the RGBA array for process_inpaint.

    Per this module's convention, alpha 0 marks pixels to remove and alpha 255
    marks pixels to keep.

    - RGBA input is returned as-is: its alpha channel is assumed to already
      follow that convention.
    - Any other mode is reduced to grayscale. Every pixel brighter than 0
      (anything not pure black) is marked for removal, and the RGB channels
      of the result are all zero.
    """
    if img.mode == "RGBA":
        return np.array(img)
    luminance = np.asarray(img.convert("L"))
    rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
    rgba[..., 3] = np.where(luminance > 0, 0, 255)
    return rgba
def _stored_path(file_id: str, kind: str) -> str:
    """Return the on-disk path of an upload registered in ``file_store``.

    Raises:
        HTTPException: 404 ``"<kind>_id not found"`` when *file_id* is unknown
            or was uploaded as the other kind (image vs. mask).
    """
    entry = file_store.get(file_id)
    if entry is None or entry["type"] != kind:
        raise HTTPException(status_code=404, detail=f"{kind}_id not found")
    return entry["path"]


def _inpaint_to_output(img_rgba: np.ndarray, mask_rgba: np.ndarray) -> str:
    """Run ``process_inpaint`` and save the result as a PNG in OUTPUT_DIR.

    Returns the generated file name (``output_<hex>.png``). The /download and
    /result endpoints accept this name.
    """
    result = process_inpaint(img_rgba, mask_rgba)
    result_name = f"output_{uuid.uuid4().hex}.png"
    Image.fromarray(result).save(os.path.join(OUTPUT_DIR, result_name))
    return result_name


def _log_to_mongo(document: Dict[str, object]) -> None:
    """Insert *document* into the Mongo log collection, if one is configured.

    Best effort: a failed insert is printed and swallowed, so a request whose
    result is already saved is not turned into a 500.
    """
    # pymongo Collection objects do not support truth-value testing (bool()
    # raises NotImplementedError), so `if mongo_collection:` crashed every
    # request whenever Mongo was configured. Compare with None instead.
    if mongo_collection is None:
        return
    try:
        mongo_collection.insert_one(document)
    except Exception as e:
        print("Mongo logging failed:", e)


def _inpaint_stored(req: InpaintRequest) -> str:
    """Inpaint a previously uploaded image/mask pair and return the result name."""
    # Validate both ids before doing any image work (image first, as before).
    image_path = _stored_path(req.image_id, "image")
    mask_path = _stored_path(req.mask_id, "mask")
    img_rgba = _load_rgba_image(image_path)
    # The mask may be RGB/gray/RGBA. Close its file once the pixels are read.
    with Image.open(mask_path) as mask_img:
        mask_rgba = _load_rgba_mask_from_image(mask_img)
    return _inpaint_to_output(np.array(img_rgba), mask_rgba)


@app.post("/inpaint")
def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Inpaint uploaded ids and return ``{"result": <output file name>}``.

    Raises 404 if either id is unknown or of the wrong kind.
    """
    result_name = _inpaint_stored(req)
    timestamp = datetime.utcnow().isoformat()
    logs.append({"result": result_name, "timestamp": timestamp})
    _log_to_mongo({"endpoint": "/inpaint", "result": result_name, "timestamp": timestamp})
    return {"result": result_name}


@app.post("/inpaint-url")
def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_auth)) -> Dict[str, str]:
    """Same as /inpaint, but the JSON also includes an absolute download URL.

    Returns ``{"result": <output file name>, "url": <GET /download/... URL>}``.
    """
    result_name = _inpaint_stored(req)
    url = str(request.url_for("download_file", filename=result_name))
    timestamp = datetime.utcnow().isoformat()
    logs.append({"result": result_name, "url": url, "timestamp": timestamp})
    # Logged to Mongo like the other inpaint endpoints.
    _log_to_mongo(
        {"endpoint": "/inpaint-url", "result": result_name, "url": url, "timestamp": timestamp}
    )
    return {"result": result_name, "url": url}


@app.post("/inpaint-multipart")
def inpaint_multipart(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    request: Request = None,
    _: None = Depends(bearer_auth),
) -> Dict[str, str]:
    """Inpaint an image/mask pair sent directly as multipart files.

    Returns ``{"result": <output file name>}``, plus ``"url"`` when a download
    URL could be built for the result.
    """
    img = Image.open(image.file).convert("RGBA")
    mask_rgba = _load_rgba_mask_from_image(Image.open(mask.file))
    result_name = _inpaint_to_output(np.array(img), mask_rgba)

    url: Optional[str] = None
    if request is not None:
        try:
            url = str(request.url_for("download_file", filename=result_name))
        except Exception:
            # The URL is optional here; the result name alone is still usable.
            url = None

    timestamp = datetime.utcnow().isoformat()
    entry: Dict[str, str] = {"result": result_name, "timestamp": timestamp}
    if url:
        entry["url"] = url
    logs.append(entry)
    _log_to_mongo(
        {"endpoint": "/inpaint-multipart", "result": result_name, "url": url, "timestamp": timestamp}
    )

    resp: Dict[str, str] = {"result": result_name}
    if url:
        resp["url"] = url
    return resp
def _output_path(filename: str) -> str:
    """Resolve *filename* to an existing file directly inside OUTPUT_DIR.

    Raises:
        HTTPException: 404 "file not found" if *filename* has a directory
            component (defence in depth against path traversal; the route
            pattern should already exclude "/") or does not name a regular
            file in OUTPUT_DIR.
    """
    if os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="file not found")
    path = os.path.join(OUTPUT_DIR, filename)
    if not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="file not found")
    return path


@app.get("/download/{filename}")
def download_file(filename: str):
    """Serve a result image by name.

    Unauthenticated. Its route name is used by url_for in the inpaint
    endpoints.
    """
    return FileResponse(_output_path(filename))


@app.get("/result/{filename}")
def view_result(filename: str):
    """View a result image in the browser; served with an explicit image/png type."""
    return FileResponse(_output_path(filename), media_type="image/png")
@app.get("/logs")
def get_logs(_: None = Depends(bearer_auth)) -> JSONResponse:
    """Return every in-memory log entry in insertion order.

    Guarded by bearer_auth.
    """
    entries = logs
    return JSONResponse(entries)