Spaces:
Running
on
T4
Running
on
T4
Commit
·
20727d7
1
Parent(s):
e7611db
fix(mask): ensure proper mask interpretation - selected areas are removed; add detailed logging
Browse files- api/main.py +16 -5
- src/core.py +17 -5
api/main.py
CHANGED
|
@@ -125,29 +125,40 @@ def _load_rgba_image(path: str) -> Image.Image:
|
|
| 125 |
|
| 126 |
|
| 127 |
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Turn a mask image into an ``(height, width, 4)`` uint8 RGBA array.

    The mask is carried in the alpha channel of the result: alpha=0 marks
    a pixel as "remove", alpha=255 marks it as "keep".

    * Non-RGBA input is reduced to grayscale; pixels brighter than 128
      (white) get alpha=0, all others get alpha=255. The RGB channels of
      the result are zero.
    * RGBA input whose mean alpha exceeds 200 is treated as an opaque
      image that encodes the mask in its colours: the same 128 threshold
      is applied to the grayscale of its RGB channels, and the alpha
      channel of a copy is overwritten with the result.
    * Any other RGBA input is returned as-is (its alpha is not binarized).
    """
    if img.mode == "RGBA":
        pixels = np.array(img)
        # Written as `not (mean > 200)` so a NaN mean (empty image) still
        # falls through to the "alpha already encodes the mask" path.
        if not (pixels[:, :, 3].mean() > 200):
            return pixels

        # Alpha is (almost) fully opaque, so the mask lives in the colours:
        # white in RGB = remove.
        luminance = cv2.cvtColor(pixels[:, :, :3], cv2.COLOR_RGB2GRAY)
        result = pixels.copy()
        result[:, :, 3] = np.where(luminance > 128, 0, 255).astype(np.uint8)
        return result

    # RGB / grayscale masks: white (>128) = remove, black (<=128) = keep.
    luminance = np.array(img.convert("L"))
    result = np.zeros((img.height, img.width, 4), dtype=np.uint8)
    result[:, :, 3] = np.where(luminance > 128, 0, 255).astype(np.uint8)
    return result
|
| 152 |
|
| 153 |
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
    """Turn a mask image into an ``(height, width, 4)`` uint8 RGBA array.

    Input convention: white (255) marks the area to remove, black (0) the
    area to keep. In the returned array the mask is carried by the alpha
    channel, where alpha=0 means "remove" and alpha=255 means "keep".
    NOTE(review): the caller is expected to invert this alpha when
    ``invert_mask=True`` -- confirm against ``process_inpaint``.

    * Non-RGBA input is reduced to grayscale and thresholded at 128; the
      RGB channels of the result are zero.
    * RGBA input whose mean alpha exceeds 200 is treated as an opaque
      image that encodes the mask in its colours: the grayscale of its RGB
      channels is thresholded at 128 and written into the alpha channel of
      a copy.
    * Any other RGBA input is returned as-is (its alpha is not binarized).

    Each path logs how many pixels it considers marked for removal.
    """
    if img.mode == "RGBA":
        arr = np.array(img)
        alpha = arr[:, :, 3]

        # Written as `not (mean > 200)` so a NaN mean (empty image) still
        # takes the alpha-based path.
        if not (alpha.mean() > 200):
            # Alpha channel already encodes the mask; hand it back untouched.
            log.info(f"Loaded RGBA mask (alpha-based): {int((alpha < 128).sum())} pixels marked for removal (alpha<128)")
            return arr

        # Alpha is mostly opaque everywhere, so read the mask from the
        # colours instead: white in RGB = remove.
        brightness = cv2.cvtColor(arr[:, :, :3], cv2.COLOR_RGB2GRAY)
        alpha = np.where(brightness > 128, 0, 255).astype(np.uint8)
        out = arr.copy()
        out[:, :, 3] = alpha
        log.info(f"Loaded RGBA mask (RGB-based): {int((alpha == 0).sum())} pixels marked for removal (alpha=0)")
        return out

    # RGB / grayscale masks: white (>128) -> alpha=0 (remove),
    # black (<=128) -> alpha=255 (keep).
    brightness = np.array(img.convert("L"))
    alpha = np.where(brightness > 128, 0, 255).astype(np.uint8)
    out = np.zeros((img.height, img.width, 4), dtype=np.uint8)
    out[:, :, 3] = alpha
    log.info(f"Loaded {img.mode} mask: {int((alpha == 0).sum())} pixels marked for removal (alpha=0)")
    return out
|
| 163 |
|
| 164 |
|
src/core.py
CHANGED
|
@@ -459,16 +459,28 @@ def process_inpaint(image, mask, invert_mask=True):
|
|
| 459 |
image = norm_img(image)
|
| 460 |
|
| 461 |
# Convert RGBA mask to single-channel mask.
|
| 462 |
-
# Standard:
|
| 463 |
-
#
|
| 464 |
alpha_channel = mask[:,:,3]
|
| 465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 467 |
|
| 468 |
-
# Debug: log mask statistics
|
| 469 |
mask_nonzero = int((mask > 128).sum())
|
| 470 |
-
|
|
|
|
| 471 |
|
|
|
|
|
|
|
| 472 |
mask = norm_img(mask)
|
| 473 |
|
| 474 |
res_np_img = run(image, mask)
|
|
|
|
| 459 |
image = norm_img(image)
|
| 460 |
|
| 461 |
# Convert RGBA mask to single-channel mask.
|
| 462 |
+
# Standard LaMa convention: 1 = remove, 0 = keep
|
| 463 |
+
# User draws with alpha=0 (transparent), we want those to become 1 (remove)
|
| 464 |
alpha_channel = mask[:,:,3]
|
| 465 |
+
|
| 466 |
+
# When invert_mask=True: alpha=0 (painted/transparent) → 255 → 1 (remove)
|
| 467 |
+
# When invert_mask=False: alpha=255 (opaque) → 255 → 1 (remove)
|
| 468 |
+
if invert_mask:
|
| 469 |
+
# Inverted: transparent (0) means remove, opaque (255) means keep
|
| 470 |
+
mask = 255 - alpha_channel
|
| 471 |
+
else:
|
| 472 |
+
# Normal: opaque (255) means remove, transparent (0) means keep
|
| 473 |
+
mask = alpha_channel
|
| 474 |
+
|
| 475 |
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
| 476 |
|
| 477 |
+
# Debug: log mask statistics BEFORE normalization
|
| 478 |
mask_nonzero = int((mask > 128).sum())
|
| 479 |
+
mask_total = mask.shape[0] * mask.shape[1]
|
| 480 |
+
print(f"Mask shape: {mask.shape}, pixels to remove (>128): {mask_nonzero}/{mask_total} ({100*mask_nonzero/mask_total:.1f}%)")
|
| 481 |
|
| 482 |
+
# Normalize: values > 0 become 1.0, 0 stays 0
|
| 483 |
+
# After this, 1.0 = remove, 0.0 = keep (LaMa expects this)
|
| 484 |
mask = norm_img(mask)
|
| 485 |
|
| 486 |
res_np_img = run(image, mask)
|