LogicGoInfotechSpaces commited on
Commit
89c8105
·
1 Parent(s): ed7d157

fix: improve image quality with LANCZOS4 interpolation, better mask validation, and high-quality PNG output

Browse files
Files changed (2) hide show
  1. api/main.py +7 -7
  2. src/core.py +31 -5
api/main.py CHANGED
@@ -204,7 +204,7 @@ def inpaint(req: InpaintRequest, _: None = Depends(bearer_auth)) -> Dict[str, st
204
  result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)
205
  result_name = f"output_{uuid.uuid4().hex}.png"
206
  result_path = os.path.join(OUTPUT_DIR, result_name)
207
- Image.fromarray(result).save(result_path)
208
 
209
  logs.append({"result": result_name, "timestamp": datetime.utcnow().isoformat()})
210
  return {"result": result_name}
@@ -228,7 +228,7 @@ def inpaint_url(req: InpaintRequest, request: Request, _: None = Depends(bearer_
228
  result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)
229
  result_name = f"output_{uuid.uuid4().hex}.png"
230
  result_path = os.path.join(OUTPUT_DIR, result_name)
231
- Image.fromarray(result).save(result_path)
232
 
233
  url = str(request.url_for("download_file", filename=result_name))
234
  logs.append({"result": result_name, "url": url, "timestamp": datetime.utcnow().isoformat()})
@@ -254,7 +254,7 @@ def inpaint_multipart(
254
  result = np.array(img.convert("RGB"))
255
  result_name = f"output_{uuid.uuid4().hex}.png"
256
  result_path = os.path.join(OUTPUT_DIR, result_name)
257
- Image.fromarray(result).save(result_path)
258
 
259
  url: Optional[str] = None
260
  try:
@@ -326,7 +326,7 @@ def inpaint_multipart(
326
  result = np.array(img.convert("RGB")) if img else np.array(m.convert("RGB"))
327
  result_name = f"output_{uuid.uuid4().hex}.png"
328
  result_path = os.path.join(OUTPUT_DIR, result_name)
329
- Image.fromarray(result).save(result_path)
330
  return {"result": result_name, "error": "pink/magenta paint detection failed - very few pixels detected"}
331
 
332
  # Create black/white mask: white = remove (pink areas), black = keep (everything else)
@@ -349,7 +349,7 @@ def inpaint_multipart(
349
  result = process_inpaint(np.array(img), mask_rgba, invert_mask=actual_invert)
350
  result_name = f"output_{uuid.uuid4().hex}.png"
351
  result_path = os.path.join(OUTPUT_DIR, result_name)
352
- Image.fromarray(result).save(result_path)
353
 
354
  url: Optional[str] = None
355
  try:
@@ -417,7 +417,7 @@ def remove_pink_segments(
417
  result = np.array(img.convert("RGB"))
418
  result_name = f"output_{uuid.uuid4().hex}.png"
419
  result_path = os.path.join(OUTPUT_DIR, result_name)
420
- Image.fromarray(result).save(result_path)
421
  return {
422
  "result": result_name,
423
  "error": "No pink/magenta segments detected. Please paint areas to remove with magenta/pink color (RGB 255,0,255)."
@@ -439,7 +439,7 @@ def remove_pink_segments(
439
  result = process_inpaint(np.array(img), mask_rgba, invert_mask=False)
440
  result_name = f"output_{uuid.uuid4().hex}.png"
441
  result_path = os.path.join(OUTPUT_DIR, result_name)
442
- Image.fromarray(result).save(result_path)
443
 
444
  url: Optional[str] = None
445
  try:
 
204
  result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)
205
  result_name = f"output_{uuid.uuid4().hex}.png"
206
  result_path = os.path.join(OUTPUT_DIR, result_name)
207
+ Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
208
 
209
  logs.append({"result": result_name, "timestamp": datetime.utcnow().isoformat()})
210
  return {"result": result_name}
 
228
  result = process_inpaint(np.array(img_rgba), mask_rgba, invert_mask=req.invert_mask)
229
  result_name = f"output_{uuid.uuid4().hex}.png"
230
  result_path = os.path.join(OUTPUT_DIR, result_name)
231
+ Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
232
 
233
  url = str(request.url_for("download_file", filename=result_name))
234
  logs.append({"result": result_name, "url": url, "timestamp": datetime.utcnow().isoformat()})
 
254
  result = np.array(img.convert("RGB"))
255
  result_name = f"output_{uuid.uuid4().hex}.png"
256
  result_path = os.path.join(OUTPUT_DIR, result_name)
257
+ Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
258
 
259
  url: Optional[str] = None
260
  try:
 
326
  result = np.array(img.convert("RGB")) if img else np.array(m.convert("RGB"))
327
  result_name = f"output_{uuid.uuid4().hex}.png"
328
  result_path = os.path.join(OUTPUT_DIR, result_name)
329
+ Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
330
  return {"result": result_name, "error": "pink/magenta paint detection failed - very few pixels detected"}
331
 
332
  # Create black/white mask: white = remove (pink areas), black = keep (everything else)
 
349
  result = process_inpaint(np.array(img), mask_rgba, invert_mask=actual_invert)
350
  result_name = f"output_{uuid.uuid4().hex}.png"
351
  result_path = os.path.join(OUTPUT_DIR, result_name)
352
+ Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
353
 
354
  url: Optional[str] = None
355
  try:
 
417
  result = np.array(img.convert("RGB"))
418
  result_name = f"output_{uuid.uuid4().hex}.png"
419
  result_path = os.path.join(OUTPUT_DIR, result_name)
420
+ Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
421
  return {
422
  "result": result_name,
423
  "error": "No pink/magenta segments detected. Please paint areas to remove with magenta/pink color (RGB 255,0,255)."
 
439
  result = process_inpaint(np.array(img), mask_rgba, invert_mask=False)
440
  result_name = f"output_{uuid.uuid4().hex}.png"
441
  result_path = os.path.join(OUTPUT_DIR, result_name)
442
+ Image.fromarray(result).save(result_path, "PNG", optimize=False, compress_level=1)
443
 
444
  url: Optional[str] = None
445
  try:
src/core.py CHANGED
@@ -449,11 +449,16 @@ def process_inpaint(image, mask, invert_mask=True):
449
  """
450
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
451
  original_shape = image.shape
452
- interpolation = cv2.INTER_CUBIC
 
453
 
454
- size_limit = max(image.shape)
 
 
 
455
 
456
  print(f"Origin image shape: {original_shape}")
 
457
  image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
458
  print(f"Resized image shape: {image.shape}")
459
  image = norm_img(image)
@@ -489,13 +494,25 @@ def process_inpaint(image, mask, invert_mask=True):
489
  if not invert_mask:
490
  mask = 255 - mask # double invert back to original
491
 
492
- mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
 
493
 
494
- # Debug: log mask statistics
495
  mask_nonzero = int((mask > 128).sum())
496
  mask_total = mask.shape[0] * mask.shape[1]
497
  print(f"Mask shape: {mask.shape}, pixels to remove (>128): {mask_nonzero}/{mask_total} ({100*mask_nonzero/mask_total:.1f}%)")
498
 
 
 
 
 
 
 
 
 
 
 
 
499
  # Normalize: values > 0 become 1.0, 0 stays 0 (LaMa expects this)
500
  mask = norm_img(mask)
501
 
@@ -504,8 +521,17 @@ def process_inpaint(image, mask, invert_mask=True):
504
  print(f"After normalization: {mask_final_pixels} pixels marked for removal (value > 0.5)")
505
 
506
  if mask_final_pixels < 10:
507
- print("WARNING: Very few pixels marked for removal! Check mask format.")
 
 
 
508
 
509
  res_np_img = run(image, mask)
510
 
 
 
 
 
 
 
511
  return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)
 
449
  """
450
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
451
  original_shape = image.shape
452
+ # Use INTER_LANCZOS4 for better quality (higher quality interpolation)
453
+ interpolation = cv2.INTER_LANCZOS4
454
 
455
+ # Increase size limit to preserve quality (up to 2048px max dimension)
456
+ # Reference model uses max(image.shape) but we can optimize for quality
457
+ max_dimension = max(image.shape)
458
+ size_limit = min(max_dimension, 2048) # Cap at 2048 for quality/speed balance
459
 
460
  print(f"Origin image shape: {original_shape}")
461
+ print(f"Size limit: {size_limit} (max dimension was {max_dimension})")
462
  image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
463
  print(f"Resized image shape: {image.shape}")
464
  image = norm_img(image)
 
494
  if not invert_mask:
495
  mask = 255 - mask # double invert back to original
496
 
497
+ # Resize mask to match image dimensions (use INTER_NEAREST for binary mask)
498
+ mask = resize_max_size(mask, size_limit=size_limit, interpolation=cv2.INTER_NEAREST)
499
 
500
+ # Debug: log mask statistics BEFORE normalization
501
  mask_nonzero = int((mask > 128).sum())
502
  mask_total = mask.shape[0] * mask.shape[1]
503
  print(f"Mask shape: {mask.shape}, pixels to remove (>128): {mask_nonzero}/{mask_total} ({100*mask_nonzero/mask_total:.1f}%)")
504
 
505
+ if mask_nonzero < 10:
506
+ print("ERROR: Mask is empty or almost empty! Cannot proceed with inpainting.")
507
+ print("DEBUG INFO:")
508
+ print(f" - Alpha channel mean: {alpha_mean}")
509
+ print(f" - RGB channels min/max: {rgb_channels.min()}/{rgb_channels.max()}")
510
+ print(f" - Alpha channel min/max: {alpha_channel.min()}/{alpha_channel.max()}")
511
+ # Return original image if mask is invalid
512
+ return cv2.cvtColor(cv2.resize(cv2.cvtColor(np.array(image*255, dtype=np.uint8), cv2.COLOR_RGB2BGR),
513
+ (original_shape[1], original_shape[0]),
514
+ interpolation=cv2.INTER_LANCZOS4), cv2.COLOR_BGR2RGB)
515
+
516
  # Normalize: values > 0 become 1.0, 0 stays 0 (LaMa expects this)
517
  mask = norm_img(mask)
518
 
 
521
  print(f"After normalization: {mask_final_pixels} pixels marked for removal (value > 0.5)")
522
 
523
  if mask_final_pixels < 10:
524
+ print("ERROR: After normalization, mask is still empty! Returning original image.")
525
+ return cv2.cvtColor(cv2.resize(cv2.cvtColor(np.array(image*255, dtype=np.uint8), cv2.COLOR_RGB2BGR),
526
+ (original_shape[1], original_shape[0]),
527
+ interpolation=cv2.INTER_LANCZOS4), cv2.COLOR_BGR2RGB)
528
 
529
  res_np_img = run(image, mask)
530
 
531
+ # Resize back to original dimensions if needed (for quality preservation)
532
+ if res_np_img.shape[:2] != original_shape[:2]:
533
+ res_np_img = cv2.resize(res_np_img, (original_shape[1], original_shape[0]),
534
+ interpolation=cv2.INTER_LANCZOS4)
535
+ print(f"Resized output back to original: {res_np_img.shape}")
536
+
537
  return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)