import os

# Set before the numeric / ML libraries below are imported: OpenMP reads
# OMP_NUM_THREADS when its runtime initializes, so setting it later may have
# no effect.
os.environ["OMP_NUM_THREADS"] = "1"

import gradio as gr
import cv2
import shutil
import uuid
import insightface
from insightface.app import FaceAnalysis
from huggingface_hub import hf_hub_download
import subprocess
import numpy as np
import threading
from fastapi import FastAPI, UploadFile, File, HTTPException, Response
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
# AsyncIOMotorGridFSBucket is part of Motor; pymongo's `gridfs` package does
# not export it, so importing it from there fails with ImportError.
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
from bson.objectid import ObjectId
from gradio import mount_gradio_app
import uvicorn
import logging
import io

# -------------------------------------------------
# Logging
# -------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -------------------------------------------------
# Paths
# -------------------------------------------------
REPO_ID = "HariLogicgo/face_swap_models"  # Hugging Face repo with the ONNX weights
BASE_DIR = "./workspace"
UPLOAD_DIR = os.path.join(BASE_DIR, "uploads")  # intermediate swapped images
RESULT_DIR = os.path.join(BASE_DIR, "results")  # CodeFormer output
MODELS_DIR = "./models"

os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)


# -------------------------------------------------
# Download models
# -------------------------------------------------
def download_models():
    """Download the inswapper and buffalo_l ONNX models from REPO_ID.

    Files keep their repo-relative ``models/...`` prefix under MODELS_DIR, so
    the buffalo_l files land in ``MODELS_DIR/models/buffalo_l``.  This is
    presumably the layout ``FaceAnalysis(name="buffalo_l", root=MODELS_DIR)``
    expects; confirm against the installed insightface version.

    Returns:
        The local filesystem path of ``inswapper_128.onnx``, as returned by
        ``hf_hub_download``.
    """
    logger.info("Downloading models...")
    inswapper_path = hf_hub_download(
        repo_id=REPO_ID,
        filename="models/inswapper_128.onnx",
        repo_type="model",
        local_dir=MODELS_DIR,
    )

    buffalo_files = [
        "1k3d68.onnx",
        "2d106det.onnx",
        "genderage.onnx",
        "det_10g.onnx",
        "w600k_r50.onnx",
    ]
    for f in buffalo_files:
        hf_hub_download(
            repo_id=REPO_ID,
            filename=f"models/buffalo_l/{f}",
            repo_type="model",
            local_dir=MODELS_DIR,
        )

    logger.info("Models downloaded successfully")
    return inswapper_path


inswapper_path = download_models()

# -------------------------------------------------
# Face Analysis + Swapper
# -------------------------------------------------
# ONNX Runtime tries CUDA first and falls back to CPU.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
logger.info(f"Initializing FaceAnalysis with providers: {providers}")
face_analysis_app = FaceAnalysis(name="buffalo_l", root=MODELS_DIR, providers=providers)
face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))
swapper = insightface.model_zoo.get_model(inswapper_path, providers=providers)
logger.info("FaceAnalysis and swapper initialized")

# -------------------------------------------------
# CodeFormer setup
# -------------------------------------------------
CODEFORMER_DIR = "CodeFormer"
CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"


def ensure_codeformer():
    """Clone CodeFormer and install its dependencies and weights on first start.

    Setup runs only when the ``CodeFormer`` directory does not exist yet.

    Raises:
        subprocess.CalledProcessError: if any setup command fails.
    """
    import sys

    if os.path.exists(CODEFORMER_DIR):
        # NOTE(review): if a previous first-start failed after the clone, the
        # directory exists but setup is incomplete and is not retried.
        return

    logger.info("Cloning CodeFormer repository...")
    subprocess.run(
        ["git", "clone", "https://github.com/sczhou/CodeFormer.git", CODEFORMER_DIR],
        check=True,
    )
    # Use the running interpreter rather than whatever `pip`/`python` is on
    # PATH, so packages land in the environment this app actually runs in.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-r",
         os.path.join(CODEFORMER_DIR, "requirements.txt")],
        check=True,
    )
    # CodeFormer's basicsr/setup.py is meant to be run from the repo root
    # (it opens its version/readme files via relative paths); confirm if the
    # upstream repo changes.
    subprocess.run(
        [sys.executable, "basicsr/setup.py", "develop"],
        cwd=CODEFORMER_DIR,
        check=True,
    )
    # Weights are downloaded relative to the current working directory; the
    # inference subprocess later also runs from here, so the cwd is kept.
    download_script = os.path.join(CODEFORMER_DIR, "scripts", "download_pretrained_models.py")
    for weights in ("facelib", "CodeFormer"):
        subprocess.run([sys.executable, download_script, weights], check=True)
    logger.info("CodeFormer setup complete")


ensure_codeformer()

# -------------------------------------------------
# MongoDB + GridFS
# -------------------------------------------------
# Connection strings carry credentials and must come from the environment;
# never hard-code a default URL with a username/password here.
MONGODB_URL = os.getenv("MONGODB_URL")
if not MONGODB_URL:
    raise RuntimeError("MONGODB_URL environment variable is not set")

client = AsyncIOMotorClient(MONGODB_URL)
database = client.FaceSwap
fs_bucket = AsyncIOMotorGridFSBucket(database)
logger.info("MongoDB + GridFS initialized")

# -------------------------------------------------
# Lock for face swap
# -------------------------------------------------
# Serializes the pipeline: it wipes and reuses the shared UPLOAD_DIR /
# RESULT_DIR, so two runs must not overlap.
swap_lock = threading.Lock()

# -------------------------------------------------
# Face Swap Pipeline
# -------------------------------------------------
# Extra imports used by the pipeline and API below.
import asyncio
import sys

from bson.errors import InvalidId
from gridfs.errors import NoFile


def face_swap_and_enhance(src_img, tgt_img):
    """Swap the first face of ``src_img`` onto ``tgt_img`` and enhance it.

    Steps: detect faces, swap the first detected source face onto the first
    detected target face, then run CodeFormer on the result.  The whole run
    holds ``swap_lock`` because it wipes and reuses UPLOAD_DIR / RESULT_DIR.

    Args:
        src_img: RGB numpy array containing the face to transplant.
        tgt_img: RGB numpy array whose first detected face is replaced.

    Returns:
        ``(final_img, final_path, "")`` on success, where ``final_img`` is the
        enhanced RGB array and ``final_path`` the PNG written under
        RESULT_DIR.  ``final_path`` is only valid until the next call starts,
        since each call wipes RESULT_DIR.  On failure, returns
        ``(None, None, message)``; exceptions are not raised to the caller.
    """
    logger.info("Starting face swap and enhancement")
    try:
        with swap_lock:
            # Start every run from empty working directories.
            shutil.rmtree(UPLOAD_DIR, ignore_errors=True)
            shutil.rmtree(RESULT_DIR, ignore_errors=True)
            os.makedirs(UPLOAD_DIR, exist_ok=True)
            os.makedirs(RESULT_DIR, exist_ok=True)

            if not isinstance(src_img, np.ndarray) or not isinstance(tgt_img, np.ndarray):
                return None, None, "❌ Invalid input images"

            # insightface operates on BGR images.
            src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
            tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)

            src_faces = face_analysis_app.get(src_bgr)
            tgt_faces = face_analysis_app.get(tgt_bgr)
            if not src_faces or not tgt_faces:
                return None, None, "❌ Face not detected"

            swapped_path = os.path.join(UPLOAD_DIR, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
            swapped_bgr = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
            if swapped_bgr is None:
                return None, None, "❌ Face swap failed"
            if not cv2.imwrite(swapped_path, swapped_bgr):
                return None, None, "❌ Failed to write swapped image"

            # Argument list instead of a shell string: no quoting issues and
            # no shell involved.
            cmd = [
                sys.executable, CODEFORMER_PATH,
                "-w", "0.7",
                "--input_path", swapped_path,
                "--output_path", RESULT_DIR,
                "--bg_upsampler", "realesrgan",
                "--face_upsample",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                return None, None, f"❌ CodeFormer failed:\n{result.stderr}"

            final_results_dir = os.path.join(RESULT_DIR, "final_results")
            final_files = (
                sorted(f for f in os.listdir(final_results_dir) if f.endswith(".png"))
                if os.path.isdir(final_results_dir)
                else []
            )
            if not final_files:
                return None, None, "❌ No enhanced image found"

            final_path = os.path.join(final_results_dir, final_files[0])
            final_bgr = cv2.imread(final_path)
            if final_bgr is None:
                return None, None, "❌ Could not read enhanced image"
            final_img = cv2.cvtColor(final_bgr, cv2.COLOR_BGR2RGB)
            return final_img, final_path, ""
    except Exception as e:
        logger.exception("Face swap pipeline failed")
        return None, None, f"❌ Error: {str(e)}"


# -------------------------------------------------
# Gradio Interface
# -------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("Face Swap")
    with gr.Row():
        src_input = gr.Image(type="numpy", label="Upload Your Face")
        tgt_input = gr.Image(type="numpy", label="Upload Target Image")
    btn = gr.Button("Swap Face")
    output_img = gr.Image(type="numpy", label="Enhanced Output")
    download = gr.File(label="⬇️ Download Enhanced Image")
    error_box = gr.Textbox(label="Logs / Errors", interactive=False)

    # Kept as a named wrapper: Gradio derives the event's API name from the
    # function name, so renaming it would change the public API endpoint.
    def process(src, tgt):
        img, path, err = face_swap_and_enhance(src, tgt)
        return img, path, err

    btn.click(process, [src_input, tgt_input], [output_img, download, error_box])


# -------------------------------------------------
# FastAPI App
# -------------------------------------------------
fastapi_app = FastAPI()


@fastapi_app.get("/")
def root():
    return RedirectResponse("/gradio")


@fastapi_app.get("/health")
async def health():
    return {"status": "healthy"}


# -------- Upload Endpoints with GridFS --------
async def _store_upload(image: UploadFile):
    """Store an uploaded file's bytes in GridFS and return the new file id."""
    contents = await image.read()
    # UploadFile.filename is optional; GridFS needs a string name.
    return await fs_bucket.upload_from_stream(image.filename or "upload", contents)


@fastapi_app.post("/source")
async def upload_source(image: UploadFile = File(...)):
    file_id = await _store_upload(image)
    return {"source_id": str(file_id)}


@fastapi_app.post("/target")
async def upload_target(image: UploadFile = File(...)):
    file_id = await _store_upload(image)
    return {"target_id": str(file_id)}


# -------- Faceswap Endpoint --------
class FaceSwapRequest(BaseModel):
    source_id: str
    target_id: str


async def _load_rgb_image(file_id, role):
    """Fetch a GridFS file by id string and decode it into an RGB array.

    Args:
        file_id: hex ObjectId string of the stored upload.
        role: "source" or "target", used only in error messages.

    Raises:
        HTTPException: 400 for a malformed id or undecodable image data,
            404 when no GridFS file has that id.
    """
    try:
        object_id = ObjectId(file_id)
    except InvalidId:
        raise HTTPException(status_code=400, detail=f"Invalid {role}_id: {file_id}") from None
    try:
        stream = await fs_bucket.open_download_stream(object_id)
    except NoFile:
        raise HTTPException(status_code=404, detail=f"{role} image not found") from None
    data = await stream.read()
    # imdecode returns None for non-image bytes; skip it for empty data.
    bgr = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR) if data else None
    if bgr is None:
        raise HTTPException(status_code=400, detail=f"{role} file is not a decodable image")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


@fastapi_app.post("/faceswap")
async def perform_faceswap(request: FaceSwapRequest):
    """Run the pipeline on two stored uploads and store the result in GridFS.

    Returns:
        ``{"result_id": <GridFS id of the enhanced PNG>}``.

    Raises:
        HTTPException: 400/404 for bad or missing inputs, 500 for pipeline
            or storage failures.
    """
    try:
        source_rgb = await _load_rgb_image(request.source_id, "source")
        target_rgb = await _load_rgb_image(request.target_id, "target")

        # The pipeline is CPU/GPU-bound and runs a subprocess; run it in a
        # worker thread so it does not block the event loop.
        final_img, _final_path, err = await asyncio.to_thread(
            face_swap_and_enhance, source_rgb, target_rgb
        )
        if err:
            raise HTTPException(status_code=500, detail=err)

        # Encode the in-memory result instead of re-reading the file path:
        # swap_lock is released on return, so a concurrent request may
        # already have wiped RESULT_DIR.
        ok, png = cv2.imencode(".png", cv2.cvtColor(final_img, cv2.COLOR_RGB2BGR))
        if not ok:
            raise HTTPException(status_code=500, detail="❌ Failed to encode enhanced image")

        result_id = await fs_bucket.upload_from_stream("enhanced.png", png.tobytes())
        return {"result_id": str(result_id)}
    except HTTPException:
        # Already carries the right status/detail; don't re-wrap it.
        raise
    except Exception as e:
        logger.exception("Face swap request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


# -------- Download Endpoint --------
@fastapi_app.get("/download/{result_id}")
async def download_result(result_id: str):
    """Return a stored result as a PNG attachment.

    Raises:
        HTTPException: 404 for a malformed or unknown id, 500 for other
            storage failures.
    """
    try:
        stream = await fs_bucket.open_download_stream(ObjectId(result_id))
        file_data = await stream.read()
    except (InvalidId, NoFile):
        raise HTTPException(status_code=404, detail="Result not found") from None
    except Exception as e:
        logger.exception("Failed to fetch result %s", result_id)
        raise HTTPException(status_code=500, detail="Failed to fetch result") from e

    return Response(
        content=file_data,
        media_type="image/png",
        headers={"Content-Disposition": "attachment; filename=enhanced.png"},
    )


# Mount Gradio
fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")

if __name__ == "__main__":
    uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)