File size: 9,977 Bytes
7a955d9
 
a520e24
7a955d9
 
 
 
 
 
a520e24
4b4508d
 
 
 
 
 
 
0fd380e
4b4508d
4b9cc62
ec15dbc
0fd380e
ec15dbc
0fd380e
 
 
ec15dbc
 
7a955d9
 
 
 
4b9cc62
7a955d9
 
 
 
 
 
 
4b9cc62
7a955d9
 
0fd380e
7a955d9
4b9cc62
ec15dbc
4b9cc62
7a955d9
4b9cc62
7a955d9
 
 
4b9cc62
 
 
 
 
 
 
 
 
 
 
 
 
 
ec15dbc
4b9cc62
 
 
7a955d9
 
0fd380e
7a955d9
ec15dbc
 
c0b300b
 
4b9cc62
ec15dbc
7a955d9
 
 
 
 
 
 
 
ec15dbc
4b9cc62
 
 
 
 
ec15dbc
7a955d9
 
 
 
0fd380e
7a955d9
4b4508d
 
 
 
 
 
0fd380e
 
4b4508d
 
4b9cc62
a520e24
4b4508d
 
 
0fd380e
a520e24
d8fb1f1
ec15dbc
7a955d9
4b4508d
 
 
 
 
7a955d9
ec15dbc
0fd380e
ec15dbc
4b4508d
 
7a955d9
c0b300b
 
4b4508d
0fd380e
7a955d9
0fd380e
4b4508d
ec15dbc
 
 
4b4508d
7a955d9
4b4508d
 
 
 
7a955d9
4b4508d
 
 
0fd380e
7a955d9
4b4508d
 
4d96937
4b4508d
7a955d9
 
a520e24
8c09773
7a955d9
4b9cc62
7a955d9
 
a520e24
d8fb1f1
7a955d9
d8fb1f1
7a955d9
d8fb1f1
 
7a955d9
a520e24
7a955d9
 
d8fb1f1
a520e24
 
7a955d9
a520e24
7a955d9
4b4508d
 
 
c0b300b
4b4508d
c0b300b
4b4508d
 
 
c0b300b
4b9cc62
4b4508d
 
0fd380e
c0b300b
4b4508d
 
0fd380e
 
4b4508d
c0b300b
4b4508d
 
0fd380e
 
4b4508d
0fd380e
4b4508d
 
 
 
c0b300b
4b4508d
0fd380e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b4508d
0fd380e
 
 
 
c0b300b
4b4508d
0fd380e
 
 
 
 
 
 
 
 
4b4508d
 
0fd380e
c0b300b
4b9cc62
 
0fd380e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import os
os.environ["OMP_NUM_THREADS"] = "1"
import gradio as gr
import cv2
import shutil
import uuid
import insightface
from insightface.app import FaceAnalysis
from huggingface_hub import hf_hub_download
import subprocess
import numpy as np
import threading
from fastapi import FastAPI, UploadFile, File, HTTPException, Response
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from motor.motor_asyncio import AsyncIOMotorClient
from bson.objectid import ObjectId
from gridfs import AsyncIOMotorGridFSBucket
from gradio import mount_gradio_app
import uvicorn
import logging
import io

# -------------------------------------------------
# Logging
# -------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# -------------------------------------------------
# Paths
# -------------------------------------------------
# Hugging Face repo holding the inswapper and buffalo_l ONNX weights.
REPO_ID = "HariLogicgo/face_swap_models"
# Scratch area for per-request files; both subdirectories are recreated below.
BASE_DIR = "./workspace"
UPLOAD_DIR, RESULT_DIR = (os.path.join(BASE_DIR, sub) for sub in ("uploads", "results"))
MODELS_DIR = "./models"

# Create the working directories up front (no-op if they already exist).
for _dir in (UPLOAD_DIR, RESULT_DIR, MODELS_DIR):
    os.makedirs(_dir, exist_ok=True)
del _dir

# -------------------------------------------------
# Download models
# -------------------------------------------------
def download_models():
    """Fetch the face-swap ONNX weights from the Hugging Face Hub.

    Downloads ``models/inswapper_128.onnx`` and the five ``buffalo_l``
    analysis models from ``REPO_ID`` into ``MODELS_DIR``.

    Returns:
        The local filesystem path of ``inswapper_128.onnx`` as reported by
        ``hf_hub_download``.
    """
    logger.info("Downloading models...")

    def _fetch(repo_path):
        # Every file comes from the same repo/type/destination; only the path varies.
        return hf_hub_download(
            repo_id=REPO_ID,
            filename=repo_path,
            repo_type="model",
            local_dir=MODELS_DIR,
        )

    swapper_file = _fetch("models/inswapper_128.onnx")

    # Components FaceAnalysis(name="buffalo_l") loads at startup.
    for onnx_name in (
        "1k3d68.onnx",
        "2d106det.onnx",
        "genderage.onnx",
        "det_10g.onnx",
        "w600k_r50.onnx",
    ):
        _fetch(f"models/buffalo_l/{onnx_name}")

    logger.info("Models downloaded successfully")
    return swapper_file

inswapper_path = download_models()

# -------------------------------------------------
# Face Analysis + Swapper
# -------------------------------------------------
# Execution providers in preference order: CUDA first, CPU second.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
logger.info("Initializing FaceAnalysis with providers: %s", providers)

face_analysis_app = FaceAnalysis(
    name="buffalo_l",
    root=MODELS_DIR,
    providers=providers,
)
# ctx_id=0 presumably selects device 0 per insightface convention — confirm;
# det_size is the detector input resolution.
face_analysis_app.prepare(ctx_id=0, det_size=(640, 640))

swapper = insightface.model_zoo.get_model(inswapper_path, providers=providers)
logger.info("FaceAnalysis and swapper initialized")

# -------------------------------------------------
# CodeFormer setup
# -------------------------------------------------
CODEFORMER_PATH = "CodeFormer/inference_codeformer.py"


def ensure_codeformer():
    """Clone and install CodeFormer if the ``CodeFormer`` checkout is missing.

    Does nothing when the ``CodeFormer`` directory already exists. Otherwise it
    clones the repository, installs its requirements, installs ``basicsr`` in
    develop mode and downloads the facelib/CodeFormer weights.

    Raises:
        subprocess.CalledProcessError: if any setup step exits non-zero.
    """
    import sys  # local import keeps this block self-contained

    if os.path.exists("CodeFormer"):
        return

    logger.info("Cloning CodeFormer repository...")
    # Use the interpreter running this app so packages land in its environment,
    # not in whatever `python`/`pip` happens to be first on PATH.
    python = sys.executable
    steps = [
        (["git", "clone", "https://github.com/sczhou/CodeFormer.git"], None),
        ([python, "-m", "pip", "install", "-r", "CodeFormer/requirements.txt"], None),
        # CodeFormer documents running `python basicsr/setup.py develop` from
        # inside the repository root; its setup script uses repo-relative paths.
        ([python, "basicsr/setup.py", "develop"], "CodeFormer"),
        # Weight downloads stay in the current directory, which is also the cwd
        # used later when inference_codeformer.py is invoked.
        ([python, "CodeFormer/scripts/download_pretrained_models.py", "facelib"], None),
        ([python, "CodeFormer/scripts/download_pretrained_models.py", "CodeFormer"], None),
    ]
    for argv, cwd in steps:
        # Argument lists (no shell) avoid quoting/injection issues entirely.
        subprocess.run(argv, cwd=cwd, check=True)
    logger.info("CodeFormer setup complete")


ensure_codeformer()

# -------------------------------------------------
# MongoDB + GridFS
# -------------------------------------------------
# NOTE(review): AsyncIOMotorGridFSBucket is provided by motor.motor_asyncio; the
# top-of-file `from gridfs import AsyncIOMotorGridFSBucket` looks wrong and may
# raise ImportError — confirm against the installed motor/pymongo versions.
#
# The connection string must come from the environment (e.g. a deployment
# secret). A database user/password used to be hard-coded here as a fallback;
# treat that credential as leaked and rotate it.
MONGODB_URL = os.getenv("MONGODB_URL")
if not MONGODB_URL:
    raise RuntimeError(
        "MONGODB_URL is not set; provide the MongoDB connection string via the environment"
    )
client = AsyncIOMotorClient(MONGODB_URL)
database = client.FaceSwap  # database named "FaceSwap"
# All uploads and results are stored as GridFS files in this database.
fs_bucket = AsyncIOMotorGridFSBucket(database)
logger.info("MongoDB + GridFS initialized")

# -------------------------------------------------
# Lock for face swap
# -------------------------------------------------
# Serialises the pipeline: every call wipes and reuses the shared UPLOAD_DIR /
# RESULT_DIR, so two runs must never overlap.
swap_lock = threading.Lock()

# -------------------------------------------------
# Face Swap Pipeline
# -------------------------------------------------
def face_swap_and_enhance(src_img, tgt_img):
    """Swap the source face onto the target image, then enhance with CodeFormer.

    Args:
        src_img: RGB numpy array containing the face to copy.
        tgt_img: RGB numpy array whose face is replaced.

    Returns:
        A ``(final_img, final_path, error)`` tuple. On success ``final_img`` is
        the enhanced RGB array, ``final_path`` is the PNG CodeFormer wrote and
        ``error`` is ``""``. On any failure the first two items are ``None`` and
        ``error`` is a human-readable message; exceptions are not propagated.

    NOTE(review): ``final_path`` lives in RESULT_DIR, which the next call
    deletes as soon as it acquires ``swap_lock``. Callers should consume it
    immediately, or use ``final_img`` instead.
    """
    import sys  # local import keeps this block self-contained

    logger.info("Starting face swap and enhancement")
    try:
        with swap_lock:
            # Start every run from empty working directories.
            shutil.rmtree(UPLOAD_DIR, ignore_errors=True)
            shutil.rmtree(RESULT_DIR, ignore_errors=True)
            os.makedirs(UPLOAD_DIR, exist_ok=True)
            os.makedirs(RESULT_DIR, exist_ok=True)

            if not isinstance(src_img, np.ndarray) or not isinstance(tgt_img, np.ndarray):
                return None, None, "❌ Invalid input images"

            # Inputs arrive as RGB; convert to OpenCV's BGR order for analysis/swap.
            src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
            tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)

            src_faces = face_analysis_app.get(src_bgr)
            tgt_faces = face_analysis_app.get(tgt_bgr)
            if not src_faces or not tgt_faces:
                return None, None, "❌ Face not detected"

            swapped_path = os.path.join(UPLOAD_DIR, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
            # Only the first detected face of each image is used.
            swapped_bgr = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
            if swapped_bgr is None:
                return None, None, "❌ Face swap failed"

            cv2.imwrite(swapped_path, swapped_bgr)

            # Argument list (no shell): paths are passed verbatim, and the
            # script runs under the same interpreter as this app.
            cmd = [
                sys.executable, CODEFORMER_PATH,
                "-w", "0.7",
                "--input_path", swapped_path,
                "--output_path", RESULT_DIR,
                "--bg_upsampler", "realesrgan",
                "--face_upsample",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error("CodeFormer failed: %s", result.stderr)
                return None, None, f"❌ CodeFormer failed:\n{result.stderr}"

            # Guard against the output directory being absent rather than
            # surfacing a raw FileNotFoundError to the user.
            final_results_dir = os.path.join(RESULT_DIR, "final_results")
            final_files = []
            if os.path.isdir(final_results_dir):
                final_files = sorted(
                    f for f in os.listdir(final_results_dir) if f.endswith(".png")
                )
            if not final_files:
                return None, None, "❌ No enhanced image found"

            final_path = os.path.join(final_results_dir, final_files[0])
            final_bgr = cv2.imread(final_path)
            if final_bgr is None:
                return None, None, "❌ Could not read enhanced image"
            final_img = cv2.cvtColor(final_bgr, cv2.COLOR_BGR2RGB)

            return final_img, final_path, ""

    except Exception as e:
        # Boundary handler: report the failure to the caller but keep the
        # traceback in the logs instead of silently discarding it.
        logger.exception("Face swap pipeline failed")
        return None, None, f"❌ Error: {str(e)}"

# -------------------------------------------------
# Gradio Interface
# -------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("Face Swap")

    # Inputs: the face to copy and the image to paste it into.
    with gr.Row():
        src_input = gr.Image(type="numpy", label="Upload Your Face")
        tgt_input = gr.Image(type="numpy", label="Upload Target Image")

    btn = gr.Button("Swap Face")

    # Outputs, in the same order as the pipeline's return tuple.
    output_img = gr.Image(type="numpy", label="Enhanced Output")
    download = gr.File(label="⬇️ Download Enhanced Image")
    error_box = gr.Textbox(label="Logs / Errors", interactive=False)

    def process(src, tgt):
        """Gradio callback: run the pipeline and hand back its (image, path, error) triple."""
        return face_swap_and_enhance(src, tgt)

    btn.click(
        fn=process,
        inputs=[src_input, tgt_input],
        outputs=[output_img, download, error_box],
    )

# -------------------------------------------------
# FastAPI App
# -------------------------------------------------
fastapi_app = FastAPI()


@fastapi_app.get("/")
def root():
    """Send requests for the bare root URL to the mounted Gradio UI."""
    return RedirectResponse(url="/gradio")


@fastapi_app.get("/health")
async def health():
    """Liveness probe: reports healthy whenever the app can serve requests."""
    payload = {"status": "healthy"}
    return payload

# -------- Upload Endpoints with GridFS --------
async def _store_upload(image):
    """Validate an uploaded image and store its raw bytes in GridFS.

    Raises:
        HTTPException(400): the upload is empty or OpenCV cannot decode it.

    Returns:
        The ObjectId GridFS assigned to the stored file.
    """
    contents = await image.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    # Reject non-images now instead of failing later inside /faceswap.
    decoded = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image")
    # Use a placeholder name when the client supplied no filename.
    return await fs_bucket.upload_from_stream(image.filename or "upload", contents)


@fastapi_app.post("/source")
async def upload_source(image: UploadFile = File(...)):
    """Store the source (face donor) image; returns ``{"source_id": <id>}``."""
    file_id = await _store_upload(image)
    return {"source_id": str(file_id)}


@fastapi_app.post("/target")
async def upload_target(image: UploadFile = File(...)):
    """Store the target image; returns ``{"target_id": <id>}``."""
    file_id = await _store_upload(image)
    return {"target_id": str(file_id)}

# -------- Faceswap Endpoint --------
class FaceSwapRequest(BaseModel):
    """Body of POST /faceswap: GridFS ids returned by /source and /target."""

    source_id: str
    target_id: str


async def _load_rgb_from_gridfs(file_id, label):
    """Download a GridFS file and decode it into an RGB numpy array.

    Raises:
        HTTPException(400): ``file_id`` is not a valid ObjectId, or the stored
            bytes are not a decodable image.
        HTTPException(404): no GridFS file exists with that id.
    """
    from gridfs.errors import NoFile

    if not ObjectId.is_valid(file_id):
        raise HTTPException(status_code=400, detail=f"Invalid {label}: {file_id}")
    try:
        stream = await fs_bucket.open_download_stream(ObjectId(file_id))
    except NoFile:
        raise HTTPException(status_code=404, detail=f"{label} not found: {file_id}") from None
    data = await stream.read()
    bgr = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        raise HTTPException(status_code=400, detail=f"{label} is not a readable image")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


@fastapi_app.post("/faceswap")
async def perform_faceswap(request: FaceSwapRequest):
    """Swap the stored source face onto the stored target and save the result.

    Returns:
        ``{"result_id": <GridFS id of the enhanced PNG>}``.

    Raises:
        HTTPException: 400/404 for bad or missing inputs, 500 when the
            pipeline reports an error or something unexpected fails.
    """
    import asyncio

    try:
        source_rgb = await _load_rgb_from_gridfs(request.source_id, "source_id")
        target_rgb = await _load_rgb_from_gridfs(request.target_id, "target_id")

        # The pipeline is blocking (model inference + CodeFormer subprocess);
        # run it in a worker thread so the event loop keeps serving requests.
        loop = asyncio.get_running_loop()
        final_img, _final_path, err = await loop.run_in_executor(
            None, face_swap_and_enhance, source_rgb, target_rgb
        )
        if err:
            raise HTTPException(status_code=500, detail=err)

        # Encode the in-memory result rather than re-reading the file: once the
        # pipeline lock is released, another call may already have wiped it.
        ok, png = cv2.imencode(".png", cv2.cvtColor(final_img, cv2.COLOR_RGB2BGR))
        if not ok:
            raise HTTPException(status_code=500, detail="Failed to encode result image")
        result_id = await fs_bucket.upload_from_stream("enhanced.png", png.tobytes())

        return {"result_id": str(result_id)}

    except HTTPException:
        # Already carries the intended status and detail; do not re-wrap it.
        raise
    except Exception as e:
        logger.exception("Faceswap request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

# -------- Download Endpoint --------
@fastapi_app.get("/download/{result_id}")
async def download_result(result_id: str):
    """Return a stored GridFS file as a PNG attachment named ``enhanced.png``.

    Raises:
        HTTPException(404): ``result_id`` is malformed or no such file exists.
        Other failures (e.g. database errors) are not masked as 404 and
        propagate as server errors.
    """
    from gridfs.errors import NoFile

    if not ObjectId.is_valid(result_id):
        raise HTTPException(status_code=404, detail="Result not found")
    try:
        stream = await fs_bucket.open_download_stream(ObjectId(result_id))
    except NoFile:
        raise HTTPException(status_code=404, detail="Result not found") from None
    file_data = await stream.read()
    return Response(
        content=file_data,
        media_type="image/png",
        headers={"Content-Disposition": "attachment; filename=enhanced.png"},
    )

# Mount Gradio
fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")  # Gradio UI served under /gradio

if __name__ == "__main__":
    # Script entry point: listen on all interfaces, port 7860.
    uvicorn.run(app=fastapi_app, port=7860, host="0.0.0.0")