LogicGoInfotechSpaces committed on
Commit
20727d7
·
1 Parent(s): e7611db

fix(mask): ensure proper mask interpretation - selected areas are removed; add detailed logging

Browse files
Files changed (2) hide show
  1. api/main.py +16 -5
  2. src/core.py +17 -5
api/main.py CHANGED
@@ -125,29 +125,40 @@ def _load_rgba_image(path: str) -> Image.Image:
125
 
126
 
127
  def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
128
- # Standard convention: white=remove (255), black=keep (0)
129
- # Convert to RGBA where alpha=0 means "to remove", alpha=255 means "keep"
 
 
 
 
130
  if img.mode != "RGBA":
131
  # For RGB/Grayscale masks: white (value>128) = remove, black (value<=128) = keep
132
  gray = img.convert("L")
133
  arr = np.array(gray)
134
- # White pixels (>128) should have alpha=0 (to remove), black pixels (<=128) alpha=255 (keep)
 
135
  alpha = np.where(arr > 128, 0, 255).astype(np.uint8)
136
  rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
137
  rgba[:, :, 3] = alpha
 
138
  return rgba
139
- # For RGBA: check if alpha channel is used or RGB channels
 
140
  arr = np.array(img)
141
  alpha = arr[:, :, 3]
142
- # If alpha is mostly opaque (mean > 200), treat RGB channels as mask values
 
143
  if alpha.mean() > 200:
144
  # Use RGB to determine mask: white in RGB = remove
145
  gray = cv2.cvtColor(arr[:, :, :3], cv2.COLOR_RGB2GRAY)
146
  alpha = np.where(gray > 128, 0, 255).astype(np.uint8)
147
  rgba = arr.copy()
148
  rgba[:, :, 3] = alpha
 
149
  return rgba
 
150
  # Alpha channel already encodes the mask
 
151
  return arr
152
 
153
 
 
125
 
126
 
127
  def _load_rgba_mask_from_image(img: Image.Image) -> np.ndarray:
128
+ """
129
+ Convert mask image to RGBA format.
130
+ Standard convention: white (255) = area to remove, black (0) = area to keep
131
+ Returns RGBA where alpha=0 means "to remove", alpha=255 means "keep"
132
+ (This will be inverted in process_inpaint if invert_mask=True)
133
+ """
134
  if img.mode != "RGBA":
135
  # For RGB/Grayscale masks: white (value>128) = remove, black (value<=128) = keep
136
  gray = img.convert("L")
137
  arr = np.array(gray)
138
+ # White pixels (>128) should have alpha=0 (to remove after inversion)
139
+ # Black pixels (<=128) should have alpha=255 (to keep after inversion)
140
  alpha = np.where(arr > 128, 0, 255).astype(np.uint8)
141
  rgba = np.zeros((img.height, img.width, 4), dtype=np.uint8)
142
  rgba[:, :, 3] = alpha
143
+ log.info(f"Loaded {img.mode} mask: {int((alpha == 0).sum())} pixels marked for removal (alpha=0)")
144
  return rgba
145
+
146
+ # For RGBA: check if alpha channel is meaningful
147
  arr = np.array(img)
148
  alpha = arr[:, :, 3]
149
+
150
+ # If alpha is mostly opaque everywhere (mean > 200), treat RGB channels as mask values
151
  if alpha.mean() > 200:
152
  # Use RGB to determine mask: white in RGB = remove
153
  gray = cv2.cvtColor(arr[:, :, :3], cv2.COLOR_RGB2GRAY)
154
  alpha = np.where(gray > 128, 0, 255).astype(np.uint8)
155
  rgba = arr.copy()
156
  rgba[:, :, 3] = alpha
157
+ log.info(f"Loaded RGBA mask (RGB-based): {int((alpha == 0).sum())} pixels marked for removal (alpha=0)")
158
  return rgba
159
+
160
  # Alpha channel already encodes the mask
161
+ log.info(f"Loaded RGBA mask (alpha-based): {int((alpha < 128).sum())} pixels marked for removal (alpha<128)")
162
  return arr
163
 
164
 
src/core.py CHANGED
@@ -459,16 +459,28 @@ def process_inpaint(image, mask, invert_mask=True):
459
  image = norm_img(image)
460
 
461
  # Convert RGBA mask to single-channel mask.
462
- # Standard: white=remove (255), black=keep (0)
463
- # When invert_mask=True (default): alpha=0 (transparent/painted) → 255 (remove), alpha=255 → 0 (keep)
464
  alpha_channel = mask[:,:,3]
465
- mask = (255 - alpha_channel) if invert_mask else alpha_channel
 
 
 
 
 
 
 
 
 
466
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
467
 
468
- # Debug: log mask statistics
469
  mask_nonzero = int((mask > 128).sum())
470
- print(f"Mask shape: {mask.shape}, non-zero pixels (>128): {mask_nonzero}")
 
471
 
 
 
472
  mask = norm_img(mask)
473
 
474
  res_np_img = run(image, mask)
 
459
  image = norm_img(image)
460
 
461
  # Convert RGBA mask to single-channel mask.
462
+ # Standard LaMa convention: 1 = remove, 0 = keep
463
+ # User draws with alpha=0 (transparent), we want those to become 1 (remove)
464
  alpha_channel = mask[:,:,3]
465
+
466
+ # When invert_mask=True: alpha=0 (painted/transparent) → 255 → 1 (remove)
467
+ # When invert_mask=False: alpha=255 (opaque) → 255 → 1 (remove)
468
+ if invert_mask:
469
+ # Inverted: transparent (0) means remove, opaque (255) means keep
470
+ mask = 255 - alpha_channel
471
+ else:
472
+ # Normal: opaque (255) means remove, transparent (0) means keep
473
+ mask = alpha_channel
474
+
475
  mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
476
 
477
+ # Debug: log mask statistics BEFORE normalization
478
  mask_nonzero = int((mask > 128).sum())
479
+ mask_total = mask.shape[0] * mask.shape[1]
480
+ print(f"Mask shape: {mask.shape}, pixels to remove (>128): {mask_nonzero}/{mask_total} ({100*mask_nonzero/mask_total:.1f}%)")
481
 
482
+ # Normalize: values > 0 become 1.0, 0 stays 0
483
+ # After this, 1.0 = remove, 0.0 = keep (LaMa expects this)
484
  mask = norm_img(mask)
485
 
486
  res_np_img = run(image, mask)