HariLogicgo commited on
Commit
0fd380e
·
1 Parent(s): c0b300b

usig gridfs for storage

Browse files
Files changed (2) hide show
  1. app.py +62 -103
  2. test.py +41 -20
app.py CHANGED
@@ -15,11 +15,15 @@ from fastapi.responses import RedirectResponse
15
  from pydantic import BaseModel
16
  from motor.motor_asyncio import AsyncIOMotorClient
17
  from bson.objectid import ObjectId
 
18
  from gradio import mount_gradio_app
19
  import uvicorn
20
  import logging
 
21
 
22
- # Set up logging
 
 
23
  logging.basicConfig(level=logging.INFO)
24
  logger = logging.getLogger(__name__)
25
 
@@ -37,7 +41,7 @@ os.makedirs(RESULT_DIR, exist_ok=True)
37
  os.makedirs(MODELS_DIR, exist_ok=True)
38
 
39
  # -------------------------------------------------
40
- # Download models once
41
  # -------------------------------------------------
42
  def download_models():
43
  logger.info("Downloading models...")
@@ -67,7 +71,7 @@ def download_models():
67
  inswapper_path = download_models()
68
 
69
  # -------------------------------------------------
70
- # Initialize face analysis and swapper
71
  # -------------------------------------------------
72
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
73
  logger.info(f"Initializing FaceAnalysis with providers: {providers}")
@@ -94,7 +98,7 @@ def ensure_codeformer():
94
  ensure_codeformer()
95
 
96
  # -------------------------------------------------
97
- # MongoDB setup
98
  # -------------------------------------------------
99
  MONGODB_URL = os.getenv(
100
  "MONGODB_URL",
@@ -102,10 +106,8 @@ MONGODB_URL = os.getenv(
102
  )
103
  client = AsyncIOMotorClient(MONGODB_URL)
104
  database = client.FaceSwap
105
- target_images_collection = database.get_collection("Target_Images")
106
- source_images_collection = database.get_collection("Source_Images")
107
- results_collection = database.get_collection("Results")
108
- logger.info("MongoDB client initialized")
109
 
110
  # -------------------------------------------------
111
  # Lock for face swap
@@ -113,7 +115,7 @@ logger.info("MongoDB client initialized")
113
  swap_lock = threading.Lock()
114
 
115
  # -------------------------------------------------
116
- # Pipeline Function
117
  # -------------------------------------------------
118
  def face_swap_and_enhance(src_img, tgt_img):
119
  logger.info("Starting face swap and enhancement")
@@ -125,55 +127,39 @@ def face_swap_and_enhance(src_img, tgt_img):
125
  os.makedirs(RESULT_DIR, exist_ok=True)
126
 
127
  if not isinstance(src_img, np.ndarray) or not isinstance(tgt_img, np.ndarray):
128
- logger.error("Invalid input images: not numpy arrays")
129
- return None, None, "❌ Invalid input images: not numpy arrays"
130
 
131
  src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
132
  tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
133
 
134
- logger.info("Detecting faces...")
135
  src_faces = face_analysis_app.get(src_bgr)
136
  tgt_faces = face_analysis_app.get(tgt_bgr)
137
  if not src_faces or not tgt_faces:
138
- logger.error("Face not detected in one of the images")
139
- return None, None, "❌ Face not detected in one of the images"
140
 
141
- unique_name = f"swapped_{uuid.uuid4().hex[:8]}.jpg"
142
- swapped_path = os.path.join(UPLOAD_DIR, unique_name)
143
- logger.info("Performing face swap...")
144
  swapped_bgr = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
145
  if swapped_bgr is None:
146
- logger.error("Face swap failed: swapper returned None")
147
  return None, None, "❌ Face swap failed"
148
 
149
  cv2.imwrite(swapped_path, swapped_bgr)
150
- logger.info(f"Swapped image saved to {swapped_path}")
151
 
152
  cmd = f"python {CODEFORMER_PATH} -w 0.7 --input_path {swapped_path} --output_path {RESULT_DIR} --bg_upsampler realesrgan --face_upsample"
153
- logger.info(f"Running CodeFormer: {cmd}")
154
  result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
155
  if result.returncode != 0:
156
- logger.error(f"CodeFormer failed: {result.stderr}")
157
  return None, None, f"❌ CodeFormer failed:\n{result.stderr}"
158
 
159
  final_results_dir = os.path.join(RESULT_DIR, "final_results")
160
- if not os.path.exists(final_results_dir):
161
- logger.error("CodeFormer did not produce final results")
162
- return None, None, "❌ CodeFormer did not produce final results"
163
-
164
  final_files = [f for f in os.listdir(final_results_dir) if f.endswith(".png")]
165
  if not final_files:
166
- logger.error("No enhanced image found in final results")
167
- return None, None, "❌ No enhanced image found in final results"
168
 
169
  final_path = os.path.join(final_results_dir, final_files[0])
170
  final_img = cv2.cvtColor(cv2.imread(final_path), cv2.COLOR_BGR2RGB)
171
- logger.info(f"Enhanced image ready at {final_path}")
172
 
173
  return final_img, final_path, ""
174
 
175
  except Exception as e:
176
- logger.error(f"Face swap error: {str(e)}")
177
  return None, None, f"❌ Error: {str(e)}"
178
 
179
  # -------------------------------------------------
@@ -187,7 +173,6 @@ with gr.Blocks() as demo:
187
  tgt_input = gr.Image(type="numpy", label="Upload Target Image")
188
 
189
  btn = gr.Button("Swap Face")
190
-
191
  output_img = gr.Image(type="numpy", label="Enhanced Output")
192
  download = gr.File(label="⬇️ Download Enhanced Image")
193
  error_box = gr.Textbox(label="Logs / Errors", interactive=False)
@@ -211,98 +196,72 @@ def root():
211
  async def health():
212
  return {"status": "healthy"}
213
 
 
214
  @fastapi_app.post("/source")
215
  async def upload_source(image: UploadFile = File(...)):
216
- logger.info(f"Uploading source image: {image.filename}")
217
  contents = await image.read()
218
- doc = {
219
- "filename": image.filename,
220
- "content_type": image.content_type,
221
- "data": contents
222
- }
223
- result = await source_images_collection.insert_one(doc)
224
- logger.info(f"Source image uploaded with ID: {str(result.inserted_id)}")
225
- return {"source_id": str(result.inserted_id)}
226
 
227
  @fastapi_app.post("/target")
228
  async def upload_target(image: UploadFile = File(...)):
229
- logger.info(f"Uploading target image: {image.filename}")
230
  contents = await image.read()
231
- doc = {
232
- "filename": image.filename,
233
- "content_type": image.content_type,
234
- "data": contents
235
- }
236
- result = await target_images_collection.insert_one(doc)
237
- logger.info(f"Target image uploaded with ID: {str(result.inserted_id)}")
238
- return {"target_id": str(result.inserted_id)}
239
 
 
240
  class FaceSwapRequest(BaseModel):
241
  source_id: str
242
  target_id: str
243
 
244
  @fastapi_app.post("/faceswap")
245
  async def perform_faceswap(request: FaceSwapRequest):
246
- logger.info(f"Starting face swap for source_id: {request.source_id}, target_id: {request.target_id}")
247
- source_doc = await source_images_collection.find_one({"_id": ObjectId(request.source_id)})
248
- if not source_doc:
249
- logger.error(f"Source image not found: {request.source_id}")
250
- raise HTTPException(status_code=404, detail="Source image not found")
251
-
252
- target_doc = await target_images_collection.find_one({"_id": ObjectId(request.target_id)})
253
- if not target_doc:
254
- logger.error(f"Target image not found: {request.target_id}")
255
- raise HTTPException(status_code=404, detail="Target image not found")
256
-
257
- source_array = np.frombuffer(source_doc["data"], np.uint8)
258
- source_bgr = cv2.imdecode(source_array, cv2.IMREAD_COLOR)
259
- if source_bgr is None:
260
- logger.error("Failed to decode source image")
261
- raise HTTPException(status_code=500, detail="Failed to decode source image")
262
- source_rgb = cv2.cvtColor(source_bgr, cv2.COLOR_BGR2RGB)
263
-
264
- target_array = np.frombuffer(target_doc["data"], np.uint8)
265
- target_bgr = cv2.imdecode(target_array, cv2.IMREAD_COLOR)
266
- if target_bgr is None:
267
- logger.error("Failed to decode target image")
268
- raise HTTPException(status_code=500, detail="Failed to decode target image")
269
- target_rgb = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2RGB)
270
-
271
- final_img, final_path, err = face_swap_and_enhance(source_rgb, target_rgb)
272
- if err:
273
- logger.error(f"Face swap failed: {err}")
274
- raise HTTPException(status_code=500, detail=err)
275
-
276
- with open(final_path, "rb") as f:
277
- final_bytes = f.read()
278
-
279
- result_doc = {
280
- "source_id": request.source_id,
281
- "target_id": request.target_id,
282
- "filename": "enhanced.png",
283
- "content_type": "image/png",
284
- "data": final_bytes
285
- }
286
- result = await results_collection.insert_one(result_doc)
287
- logger.info(f"Face swap result stored with ID: {str(result.inserted_id)}")
288
- return {"result_id": str(result.inserted_id)}
289
 
 
 
 
 
290
  @fastapi_app.get("/download/{result_id}")
291
  async def download_result(result_id: str):
292
- logger.info(f"Downloading result: {result_id}")
293
- doc = await results_collection.find_one({"_id": ObjectId(result_id)})
294
- if not doc:
295
- logger.error(f"Result not found: {result_id}")
 
 
 
 
 
296
  raise HTTPException(status_code=404, detail="Result not found")
297
- return Response(
298
- content=doc["data"],
299
- media_type=doc["content_type"],
300
- headers={"Content-Disposition": f"attachment; filename={doc['filename']}"}
301
- )
302
 
303
- # Mount Gradio at /gradio
304
  fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
305
 
306
- # Run the app with Uvicorn
307
  if __name__ == "__main__":
308
- uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
 
15
  from pydantic import BaseModel
16
  from motor.motor_asyncio import AsyncIOMotorClient
17
  from bson.objectid import ObjectId
18
+ from gridfs import AsyncIOMotorGridFSBucket
19
  from gradio import mount_gradio_app
20
  import uvicorn
21
  import logging
22
+ import io
23
 
24
+ # -------------------------------------------------
25
+ # Logging
26
+ # -------------------------------------------------
27
  logging.basicConfig(level=logging.INFO)
28
  logger = logging.getLogger(__name__)
29
 
 
41
  os.makedirs(MODELS_DIR, exist_ok=True)
42
 
43
  # -------------------------------------------------
44
+ # Download models
45
  # -------------------------------------------------
46
  def download_models():
47
  logger.info("Downloading models...")
 
71
  inswapper_path = download_models()
72
 
73
  # -------------------------------------------------
74
+ # Face Analysis + Swapper
75
  # -------------------------------------------------
76
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
77
  logger.info(f"Initializing FaceAnalysis with providers: {providers}")
 
98
  ensure_codeformer()
99
 
100
  # -------------------------------------------------
101
+ # MongoDB + GridFS
102
  # -------------------------------------------------
103
  MONGODB_URL = os.getenv(
104
  "MONGODB_URL",
 
106
  )
107
  client = AsyncIOMotorClient(MONGODB_URL)
108
  database = client.FaceSwap
109
+ fs_bucket = AsyncIOMotorGridFSBucket(database)
110
+ logger.info("MongoDB + GridFS initialized")
 
 
111
 
112
  # -------------------------------------------------
113
  # Lock for face swap
 
115
  swap_lock = threading.Lock()
116
 
117
  # -------------------------------------------------
118
+ # Face Swap Pipeline
119
  # -------------------------------------------------
120
  def face_swap_and_enhance(src_img, tgt_img):
121
  logger.info("Starting face swap and enhancement")
 
127
  os.makedirs(RESULT_DIR, exist_ok=True)
128
 
129
  if not isinstance(src_img, np.ndarray) or not isinstance(tgt_img, np.ndarray):
130
+ return None, None, "Invalid input images"
 
131
 
132
  src_bgr = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
133
  tgt_bgr = cv2.cvtColor(tgt_img, cv2.COLOR_RGB2BGR)
134
 
 
135
  src_faces = face_analysis_app.get(src_bgr)
136
  tgt_faces = face_analysis_app.get(tgt_bgr)
137
  if not src_faces or not tgt_faces:
138
+ return None, None, "Face not detected"
 
139
 
140
+ swapped_path = os.path.join(UPLOAD_DIR, f"swapped_{uuid.uuid4().hex[:8]}.jpg")
 
 
141
  swapped_bgr = swapper.get(tgt_bgr, tgt_faces[0], src_faces[0])
142
  if swapped_bgr is None:
 
143
  return None, None, "❌ Face swap failed"
144
 
145
  cv2.imwrite(swapped_path, swapped_bgr)
 
146
 
147
  cmd = f"python {CODEFORMER_PATH} -w 0.7 --input_path {swapped_path} --output_path {RESULT_DIR} --bg_upsampler realesrgan --face_upsample"
 
148
  result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
149
  if result.returncode != 0:
 
150
  return None, None, f"❌ CodeFormer failed:\n{result.stderr}"
151
 
152
  final_results_dir = os.path.join(RESULT_DIR, "final_results")
 
 
 
 
153
  final_files = [f for f in os.listdir(final_results_dir) if f.endswith(".png")]
154
  if not final_files:
155
+ return None, None, "No enhanced image found"
 
156
 
157
  final_path = os.path.join(final_results_dir, final_files[0])
158
  final_img = cv2.cvtColor(cv2.imread(final_path), cv2.COLOR_BGR2RGB)
 
159
 
160
  return final_img, final_path, ""
161
 
162
  except Exception as e:
 
163
  return None, None, f"❌ Error: {str(e)}"
164
 
165
  # -------------------------------------------------
 
173
  tgt_input = gr.Image(type="numpy", label="Upload Target Image")
174
 
175
  btn = gr.Button("Swap Face")
 
176
  output_img = gr.Image(type="numpy", label="Enhanced Output")
177
  download = gr.File(label="⬇️ Download Enhanced Image")
178
  error_box = gr.Textbox(label="Logs / Errors", interactive=False)
 
196
  async def health():
197
  return {"status": "healthy"}
198
 
199
+ # -------- Upload Endpoints with GridFS --------
200
  @fastapi_app.post("/source")
201
  async def upload_source(image: UploadFile = File(...)):
 
202
  contents = await image.read()
203
+ file_id = await fs_bucket.upload_from_stream(image.filename, contents)
204
+ return {"source_id": str(file_id)}
 
 
 
 
 
 
205
 
206
  @fastapi_app.post("/target")
207
  async def upload_target(image: UploadFile = File(...)):
 
208
  contents = await image.read()
209
+ file_id = await fs_bucket.upload_from_stream(image.filename, contents)
210
+ return {"target_id": str(file_id)}
 
 
 
 
 
 
211
 
212
+ # -------- Faceswap Endpoint --------
213
  class FaceSwapRequest(BaseModel):
214
  source_id: str
215
  target_id: str
216
 
217
  @fastapi_app.post("/faceswap")
218
  async def perform_faceswap(request: FaceSwapRequest):
219
+ try:
220
+ # Read source
221
+ source_stream = await fs_bucket.open_download_stream(ObjectId(request.source_id))
222
+ source_bytes = await source_stream.read()
223
+ source_array = np.frombuffer(source_bytes, np.uint8)
224
+ source_bgr = cv2.imdecode(source_array, cv2.IMREAD_COLOR)
225
+ source_rgb = cv2.cvtColor(source_bgr, cv2.COLOR_BGR2RGB)
226
+
227
+ # Read target
228
+ target_stream = await fs_bucket.open_download_stream(ObjectId(request.target_id))
229
+ target_bytes = await target_stream.read()
230
+ target_array = np.frombuffer(target_bytes, np.uint8)
231
+ target_bgr = cv2.imdecode(target_array, cv2.IMREAD_COLOR)
232
+ target_rgb = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2RGB)
233
+
234
+ # Run pipeline
235
+ final_img, final_path, err = face_swap_and_enhance(source_rgb, target_rgb)
236
+ if err:
237
+ raise HTTPException(status_code=500, detail=err)
238
+
239
+ # Store result in GridFS
240
+ with open(final_path, "rb") as f:
241
+ final_bytes = f.read()
242
+ result_id = await fs_bucket.upload_from_stream("enhanced.png", final_bytes)
243
+
244
+ return {"result_id": str(result_id)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
+ except Exception as e:
247
+ raise HTTPException(status_code=500, detail=str(e))
248
+
249
+ # -------- Download Endpoint --------
250
  @fastapi_app.get("/download/{result_id}")
251
  async def download_result(result_id: str):
252
+ try:
253
+ stream = await fs_bucket.open_download_stream(ObjectId(result_id))
254
+ file_data = await stream.read()
255
+ return Response(
256
+ content=file_data,
257
+ media_type="image/png",
258
+ headers={"Content-Disposition": f"attachment; filename=enhanced.png"}
259
+ )
260
+ except Exception:
261
  raise HTTPException(status_code=404, detail="Result not found")
 
 
 
 
 
262
 
263
+ # Mount Gradio
264
  fastapi_app = mount_gradio_app(fastapi_app, demo, path="/gradio")
265
 
 
266
  if __name__ == "__main__":
267
+ uvicorn.run(fastapi_app, host="0.0.0.0", port=7860)
test.py CHANGED
@@ -1,33 +1,54 @@
1
- # import requests
2
 
3
- # response = requests.get("https://logicgoinfotechspaces-faceswap.hf.space/health")
4
- # print(response.status_code) # Should be 200
5
- # print(response.json()) # Should print: {'status': 'healthy'}
6
 
7
- # import requests
 
 
 
8
 
9
- # with open("./source.jpeg", "rb") as f:
10
- # files = {"image": ("./source.jpeg", f, "image/jpeg")}
11
- # response = requests.post("https://logicgoinfotechspaces-faceswap.hf.space/source", files=files)
12
- # print(response.status_code) # Should be 200
13
- # print(response.json()) # Should print: {'source_id': '<id>'}
14
- # source_id = response.json().get("source_id")
15
 
16
- # import requests
 
 
 
 
 
 
17
 
18
- # with open("./target.jpeg", "rb") as f:
19
- # files = {"image": ("./target.jpeg", f, "image/jpeg")}
20
- # response = requests.post("https://logicgoinfotechspaces-faceswap.hf.space/target", files=files)
21
- # print(response.status_code) # Should be 200
22
- # print(response.json()) # Should print: {'target_id': '<id>'}
23
- # target_id = response.json().get("target_id")
24
 
25
- import requests
 
 
 
 
 
 
26
 
 
 
 
27
  source_id = "68bff17eb3b2e06c24cafa22" # Replace with actual ID
28
  target_id = "68bff1a1b3b2e06c24cafa23" # Replace with actual ID
29
  payload = {"source_id": source_id, "target_id": target_id}
30
  response = requests.post("https://logicgoinfotechspaces-faceswap.hf.space/faceswap", json=payload)
31
  print(response.status_code) # Should be 200
32
  print(response.json()) # Should print: {'result_id': '<id>'}
33
- result_id = response.json().get("result_id")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
 
 
 
 
3
 
4
+ #HEALTH CHECK
5
+ response = requests.get("https://logicgoinfotechspaces-faceswap.hf.space/health")
6
+ print(response.status_code) # Should be 200
7
+ print(response.json()) # Should print: {'status': 'healthy'}
8
 
9
+ #----------------------------------------------------------------------------------------------------#
 
 
 
 
 
10
 
11
+ #UPLOAD SOURCE IMAGE
12
+ with open("./source.jpeg", "rb") as f:
13
+ files = {"image": ("./source.jpeg", f, "image/jpeg")}
14
+ response = requests.post("https://logicgoinfotechspaces-faceswap.hf.space/source", files=files)
15
+ print(response.status_code) # Should be 200
16
+ print(response.json()) # Should print: {'source_id': '<id>'}
17
+ source_id = response.json().get("source_id")
18
 
19
+ #----------------------------------------------------------------------------------------------------#
 
 
 
 
 
20
 
21
+ #UPLOAD TARGET IMAGE
22
+ with open("./target.jpeg", "rb") as f:
23
+ files = {"image": ("./target.jpeg", f, "image/jpeg")}
24
+ response = requests.post("https://logicgoinfotechspaces-faceswap.hf.space/target", files=files)
25
+ print(response.status_code) # Should be 200
26
+ print(response.json()) # Should print: {'target_id': '<id>'}
27
+ target_id = response.json().get("target_id")
28
 
29
+ #----------------------------------------------------------------------------------------------------#
30
+
31
+ #SWAP IMAGES
32
  source_id = "68bff17eb3b2e06c24cafa22" # Replace with actual ID
33
  target_id = "68bff1a1b3b2e06c24cafa23" # Replace with actual ID
34
  payload = {"source_id": source_id, "target_id": target_id}
35
  response = requests.post("https://logicgoinfotechspaces-faceswap.hf.space/faceswap", json=payload)
36
  print(response.status_code) # Should be 200
37
  print(response.json()) # Should print: {'result_id': '<id>'}
38
+ result_id = response.json().get("result_id")
39
+
40
+ #----------------------------------------------------------------------------------------------------#
41
+
42
+
43
+ #DOWNLOAD IMAGES
44
+ result_id = "68bff4800e80c0bb228b5962" # Replace with actual ID
45
+ response = requests.get(f"https://logicgoinfotechspaces-faceswap.hf.space/download/{result_id}")
46
+ print(response.status_code) # Should be 200
47
+ if response.status_code == 200:
48
+ with open("result.png", "wb") as f:
49
+ f.write(response.content)
50
+ print("Image downloaded as result.png")
51
+ else:
52
+ print(response.json()) # Print error details
53
+
54
+ #----------------------------------------------------------------------------------------------------#