object_remover / src /core.py
LogicGoInfotechSpaces's picture
Ensure mask matches resized image size
d7d0150
raw
history blame
18.7 kB
import base64
import json
import os
import re
import time
import uuid
from io import BytesIO
from pathlib import Path
import cv2
# For inpainting
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas
import argparse
import io
import multiprocessing
from typing import Union
import torch
def _disable_torch_jit_fusers():
    """Best-effort: switch off TorchScript's JIT kernel fusers.

    Each private ``torch._C`` toggle is applied independently. A toggle that
    does not exist in the installed torch build is skipped, and one that
    raises is ignored. Neither prevents the remaining toggles from being
    applied. Only ``Exception`` is swallowed, so ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate.
    """
    toggles = (
        "_jit_override_can_fuse_on_cpu",
        "_jit_override_can_fuse_on_gpu",
        "_jit_set_texpr_fuser_enabled",
        "_jit_set_nvfuser_enabled",
    )
    for name in toggles:
        toggle = getattr(torch._C, name, None)
        if toggle is None:
            # Private API not present in this torch version.
            continue
        try:
            toggle(False)
        except Exception:  # deliberately best-effort at import time
            pass


_disable_torch_jit_fusers()
from src.helper import (
download_model,
load_img,
norm_img,
numpy_to_bytes,
pad_img_to_modulo,
resize_max_size,
)
# Size every common native thread pool to the logical CPU count.
# NOTE(review): torch and numpy are imported above this point, so some of these
# runtimes may already have read their *_NUM_THREADS variable — confirm these
# settings take effect, or move this block ahead of those imports.
NUM_THREADS = str(multiprocessing.cpu_count())
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "MKL_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            "NUMEXPR_NUM_THREADS",
        ),
        NUM_THREADS,
    )
)
# When CACHE_DIR is set to a non-empty value, reuse it as TORCH_HOME.
if os.environ.get("CACHE_DIR"):
    os.environ["TORCH_HOME"] = os.environ["CACHE_DIR"]
#BUILD_DIR = os.environ.get("LAMA_CLEANER_BUILD_DIR", "./lama_cleaner/app/build")
# For Seam-carving
from scipy import ndimage as ndi
# ---- Seam-carving configuration ----------------------------------------
# Colour painted over seam pixels by visualize(); shown via cv2.imshow (BGR).
SEAM_COLOR: np.ndarray = np.array([255, 200, 200])
# Shrink inputs wider than DOWNSIZE_WIDTH before carving (faster).
SHOULD_DOWNSIZE: bool = True
DOWNSIZE_WIDTH: int = 500
# Energy assigned to protected-mask pixels so seams avoid them.
ENERGY_MASK_CONST: float = 100000.0
# Mask pixels strictly above this value count as "marked".
MASK_THRESHOLD: int = 10
# Choose forward_energy (True) or backward_energy (False).
USE_FORWARD_ENERGY: bool = True
# Inpainting model: a TorchScript checkpoint loaded once, at import time, onto
# the device named by $DEVICE (default "cuda" when available, else "cpu").
device_str = os.environ.get(
    "DEVICE",
    "cuda" if torch.cuda.is_available() else "cpu",
)
device = torch.device(device_str)
# Relative path: resolved against the process's current working directory.
model_path = "./assets/big-lama.pt"
model = torch.jit.load(model_path, map_location=device).to(device)
model.eval()
########################################
# UTILITY CODE
########################################
def visualize(im, boolmask=None, rotate=False):
    """Show *im* in the OpenCV window "visualization" and return the frame.

    Args:
        im: image array. It is cast to uint8, which makes a copy, before drawing.
        boolmask: optional array; entries equal to False mark seam pixels,
            which are painted with SEAM_COLOR.
        rotate: if True, apply rotate_image(..., False) before display. This
            undoes the rotate_image(..., True) used for horizontal seams.

    Returns:
        The uint8 frame that was displayed.
    """
    frame = im.astype(np.uint8)
    if boolmask is not None:
        frame[np.where(np.logical_not(boolmask))] = SEAM_COLOR
    frame = rotate_image(frame, False) if rotate else frame
    cv2.imshow("visualization", frame)
    cv2.waitKey(1)
    return frame
def resize(image, width):
    """Resize *image* to *width* pixels wide, keeping the aspect ratio.

    The image is cast to float32 before cv2.resize. The target height is
    floor(h * width / w), clamped to at least 1. Without the clamp, a very
    wide, short image would yield a zero-height dsize, which cv2.resize rejects.

    Args:
        image: 2-D or 3-D image array.
        width: target width in pixels.

    Returns:
        The resized float32 array.
    """
    h, w = image.shape[:2]
    height = max(1, int(h * width / float(w)))
    return cv2.resize(image.astype('float32'), (width, height))
def rotate_image(image, clockwise):
    """Rotate *image* by a quarter turn with np.rot90.

    clockwise=True applies k=1 and clockwise=False applies k=3, so the two
    settings undo each other. Note that NumPy's k=1 turns an image
    counter-clockwise as displayed; callers only rely on the pair being
    inverses.
    """
    if clockwise:
        return np.rot90(image, 1)
    return np.rot90(image, 3)
########################################
# ENERGY FUNCTIONS
########################################
def backward_energy(im):
    """Return the gradient-magnitude energy map of a multi-channel image.

    Horizontal and vertical central differences are taken per channel with
    wraparound at the borders. The result is the square root of their
    squared sum over the channel axis (axis 2), an array of shape im.shape[:2].
    """
    kernel = np.array([1, 0, -1])
    dx = ndi.convolve1d(im, kernel, axis=1, mode='wrap')
    dy = ndi.convolve1d(im, kernel, axis=0, mode='wrap')
    return np.sqrt(np.sum(dx**2, axis=2) + np.sum(dy**2, axis=2))
def forward_energy(im):
    """
    Forward energy algorithm as described in "Improved Seam Carving for Video Retargeting"
    by Rubinstein, Shamir, Avidan.
    Vectorized code adapted from
    https://github.com/axu2/improved-seam-carving.

    Each pixel's cost is the intensity difference between the pixels that a
    seam through it would bring together, not the pixel's own gradient.

    Args:
        im: image array. It is cast to uint8 and converted with
            cv2.COLOR_BGR2GRAY, so a 3-channel image with values in [0, 255]
            is expected. Values outside that range are not preserved by the
            uint8 cast.

    Returns:
        float64 array of shape (h, w). Entry [i, j] is the transition cost of
        the cheapest predecessor (up, upper-left or upper-right) chosen by the
        row-by-row dynamic program. Row 0 is all zeros. Neighbour lookups use
        np.roll, so they wrap around the image edges.
    """
    h, w = im.shape[:2]
    im = cv2.cvtColor(im.astype(np.uint8), cv2.COLOR_BGR2GRAY).astype(np.float64)
    energy = np.zeros((h, w))
    # m: cumulative minimum cost of reaching each pixel from the top row.
    m = np.zeros((h, w))
    # Neighbour intensities (with wraparound): U = pixel above,
    # L = pixel to the left, R = pixel to the right.
    U = np.roll(im, 1, axis=0)
    L = np.roll(im, 1, axis=1)
    R = np.roll(im, -1, axis=1)
    # Transition costs. |R - L| is charged for every move (L and R become
    # adjacent). Coming from the upper-left adds |U - L|; coming from the
    # upper-right adds |U - R|.
    cU = np.abs(R - L)
    cL = np.abs(U - L) + cU
    cR = np.abs(U - R) + cU
    for i in range(1, h):
        mU = m[i-1]
        # mL[j] = mU[j-1] (upper-left), mR[j] = mU[j+1] (upper-right), wrapping.
        mL = np.roll(mU, 1)
        mR = np.roll(mU, -1)
        mULR = np.array([mU, mL, mR])
        cULR = np.array([cU[i], cL[i], cR[i]])
        # np.array() above copied the rows, so this in-place add leaves m untouched.
        mULR += cULR
        # Pick the cheapest of (up, upper-left, upper-right) per column.
        argmins = np.argmin(mULR, axis=0)
        m[i] = np.choose(argmins, mULR)
        energy[i] = np.choose(argmins, cULR)
    return energy
########################################
# SEAM HELPER FUNCTIONS
########################################
def add_seam(im, seam_idx):
    """
    Add a vertical seam to a color image at the indices provided
    by averaging the pixels values to the left and right of the seam.
    Code adapted from https://github.com/vivianhylee/seam-carving.

    For row r with c = seam_idx[r]:
      * c > 0: a new pixel equal to mean(im[r, c-1], im[r, c]) is placed at
        column c, and im[r, c:] shifts right by one.
      * c == 0: there is no left neighbour. im[r, 0] is kept, the new pixel
        mean(im[r, 0], im[r, 1]) goes to column 1, and im[r, 1:] shifts right.

    Works for any trailing channel shape (3-channel as before, but also 2-D
    or other channel counts).

    Args:
        im: array of shape (h, w, ...).
        seam_idx: sequence of h column indices.

    Returns:
        float64 array of shape (h, w + 1, ...).
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1) + im.shape[2:])
    for row in range(h):
        col = seam_idx[row]
        if col == 0:
            output[row, 0] = im[row, 0]
            output[row, 1] = np.mean(im[row, 0:2], axis=0)
            # Start the shifted copy at column 2 so the averaged pixel just
            # written to column 1 is not overwritten by im[row, 0].
            output[row, 2:] = im[row, 1:]
        else:
            output[row, :col] = im[row, :col]
            output[row, col] = np.mean(im[row, col - 1:col + 1], axis=0)
            output[row, col + 1:] = im[row, col:]
    return output
def add_seam_grayscale(im, seam_idx):
    """
    Add a vertical seam to a grayscale image at the indices provided
    by averaging the pixels values to the left and right of the seam.

    For row r with c = seam_idx[r]:
      * c > 0: a new pixel mean(im[r, c-1], im[r, c]) is placed at column c,
        and im[r, c:] shifts right by one.
      * c == 0: im[r, 0] is kept, the new pixel mean(im[r, 0], im[r, 1])
        goes to column 1, and im[r, 1:] shifts right.

    Args:
        im: 2-D array of shape (h, w).
        seam_idx: sequence of h column indices.

    Returns:
        float64 array of shape (h, w + 1).
    """
    h, w = im.shape[:2]
    output = np.zeros((h, w + 1))
    for row in range(h):
        col = seam_idx[row]
        if col == 0:
            output[row, 0] = im[row, 0]
            output[row, 1] = np.mean(im[row, 0:2])
            # Start the shifted copy at column 2 so the averaged pixel just
            # written to column 1 is not overwritten by im[row, 0].
            output[row, 2:] = im[row, 1:]
        else:
            output[row, :col] = im[row, :col]
            output[row, col] = np.mean(im[row, col - 1:col + 1])
            output[row, col + 1:] = im[row, col:]
    return output
def remove_seam(im, boolmask):
    """Drop the pixels where *boolmask* is False from a 3-channel image.

    *boolmask* is an (h, w) boolean array expected to hold exactly one False
    per row, as produced by compute_shortest_path. Otherwise the reshape
    fails or the rows end up misaligned. Returns an array of shape
    (h, w - 1, 3).
    """
    rows, cols = im.shape[:2]
    keep3 = np.repeat(np.asarray(boolmask)[:, :, np.newaxis], 3, axis=2)
    return im[keep3].reshape((rows, cols - 1, 3))
def remove_seam_grayscale(im, boolmask):
    """2-D counterpart of remove_seam.

    Keeps the entries of *im* where *boolmask* is True and returns them as
    an (h, w - 1) array.
    """
    rows, cols = im.shape[:2]
    kept = im[boolmask]
    return kept.reshape((rows, cols - 1))
def get_minimum_seam(im, mask=None, remove_mask=None):
    """Find the lowest-energy vertical seam of *im*.

    DP approach adapted from
    https://karthikkaranth.me/blog/implementing-seam-carving-with-python/

    Args:
        im: image to carve.
        mask: optional protective mask. Pixels above MASK_THRESHOLD get
            energy ENERGY_MASK_CONST so seams avoid them.
        remove_mask: optional removal mask. Pixels above MASK_THRESHOLD get
            energy -100 * ENERGY_MASK_CONST so seams pass through them. It is
            applied after *mask*, so it wins where the two overlap.

    Returns:
        (seam_idx, boolmask): the seam's column per row as an ndarray, and an
        (h, w) boolean array that is False on the seam.
    """
    height, width = im.shape[:2]
    if USE_FORWARD_ENERGY:
        energy = forward_energy(im)
    else:
        energy = backward_energy(im)
    if mask is not None:
        energy[mask > MASK_THRESHOLD] = ENERGY_MASK_CONST
    if remove_mask is not None:
        energy[remove_mask > MASK_THRESHOLD] = -ENERGY_MASK_CONST * 100
    path, boolmask = compute_shortest_path(energy, im, height, width)
    return np.array(path), boolmask
def compute_shortest_path(M, im, h, w):
    """Accumulate minimum seam costs in *M* and backtrack the cheapest vertical seam.

    Args:
        M: (h, w) energy map. It is modified in place: each row i becomes the
            minimum cumulative energy of any path that ends at that pixel and
            moves down-left, down or down-right between rows.
        im: unused; kept for interface compatibility.
        h: number of rows of M.
        w: number of columns of M.

    Returns:
        (seam_idx, boolmask): seam_idx is a list of h column indices, top row
        first. boolmask is an (h, w) bool array that is False exactly on the seam.

    Each row is processed with vectorised NumPy instead of a per-pixel Python
    loop. Candidates are compared in (left, up, right) order and np.argmin
    keeps the first minimum, so ties break the same way as a per-pixel scan
    of M[i-1, j-1:j+2].
    """
    backtrack = np.zeros_like(M, dtype=np.int_)
    cols = np.arange(w)
    # Pad the previous row with +inf on both sides. Column 0 then has no
    # usable left candidate and column w-1 no usable right candidate.
    padded = np.full(w + 2, np.inf)
    for i in range(1, h):
        prev = M[i - 1]
        padded[1:-1] = prev
        candidates = np.stack((padded[:-2], padded[1:-1], padded[2:]))
        # argmin in {0, 1, 2} -> column offset {-1, 0, +1}. The clip only
        # matters when every candidate at column 0 is +inf; it then selects
        # column 0, as the per-pixel scan did.
        parents = np.clip(cols + np.argmin(candidates, axis=0) - 1, 0, w - 1)
        backtrack[i] = parents
        M[i] += prev[parents]
    # backtrack to find path
    seam_idx = []
    boolmask = np.ones((h, w), dtype=np.bool_)
    j = np.argmin(M[-1])
    for i in range(h - 1, -1, -1):
        boolmask[i, j] = False
        seam_idx.append(j)
        j = backtrack[i, j]
    seam_idx.reverse()
    return seam_idx, boolmask
########################################
# MAIN ALGORITHM
########################################
def seams_removal(im, num_remove, mask=None, vis=False, rot=False):
    """Remove *num_remove* minimum-energy vertical seams from *im*.

    The optional protective *mask* is carved alongside the image so it stays
    aligned. When *vis* is set, each seam is shown with visualize(); *rot* is
    passed through as its rotate flag.

    Returns:
        (im, mask) after carving.
    """
    for _step in range(num_remove):
        _, keep = get_minimum_seam(im, mask)
        if vis:
            visualize(im, keep, rotate=rot)
        im = remove_seam(im, keep)
        mask = None if mask is None else remove_seam_grayscale(mask, keep)
    return im, mask
def seams_insertion(im, num_add, mask=None, vis=False, rot=False):
    """Widen *im* by *num_add* columns by duplicating low-energy seams.

    Pass 1 finds num_add seams by repeatedly removing the cheapest seam from
    a scratch copy, so successive seams differ. Pass 2 inserts them into *im*
    in discovery order with add_seam. After each insertion, every seam still
    to be inserted has its indices at or right of the inserted seam bumped
    by 2. *mask*, if given, is carved (pass 1) and grown (pass 2) alongside.

    Returns:
        (im, mask) after insertion.
    """
    recorded = []
    scratch_im = im.copy()
    scratch_mask = None if mask is None else mask.copy()
    for _ in range(num_add):
        seam_idx, keep = get_minimum_seam(scratch_im, scratch_mask)
        if vis:
            visualize(scratch_im, keep, rotate=rot)
        recorded.append(seam_idx)
        scratch_im = remove_seam(scratch_im, keep)
        if scratch_mask is not None:
            scratch_mask = remove_seam_grayscale(scratch_mask, keep)

    for k, seam in enumerate(recorded):
        im = add_seam(im, seam)
        if vis:
            visualize(im, rotate=rot)
        if mask is not None:
            mask = add_seam_grayscale(mask, seam)
        # Shift the not-yet-inserted seams in place.
        for later in recorded[k + 1:]:
            later[later >= seam] += 2
    return im, mask
########################################
# MAIN DRIVER FUNCTIONS
########################################
def seam_carve(im, dy, dx, mask=None, vis=False):
    """Resize *im* by *dy* rows and *dx* columns with seam carving.

    Negative deltas remove seams and positive deltas insert them. Columns are
    handled first. Rows are handled by rotating the image (and mask), carving
    vertical seams, then rotating the image back.

    Args:
        im: image array; cast to float64.
        dy: row delta.
        dx: column delta.
        mask: optional protective mask; cast to float64 and carved alongside.
        vis: show each seam via visualize().

    Returns:
        The carved float64 image (the mask is not returned).
    """
    im = im.astype(np.float64)
    h, w = im.shape[:2]
    assert h + dy > 0 and w + dx > 0 and dy <= h and dx <= w
    if mask is not None:
        mask = mask.astype(np.float64)

    output = im
    if dx < 0:
        output, mask = seams_removal(output, -dx, mask, vis)
    elif dx > 0:
        output, mask = seams_insertion(output, dx, mask, vis)

    if dy != 0:
        output = rotate_image(output, True)
        if mask is not None:
            mask = rotate_image(mask, True)
        if dy < 0:
            output, mask = seams_removal(output, -dy, mask, vis, rot=True)
        else:
            output, mask = seams_insertion(output, dy, mask, vis, rot=True)
        output = rotate_image(output, False)
    return output
def object_removal(im, rmask, mask=None, vis=False, horizontal_removal=False):
    """Carve away the region marked in *rmask*, then regrow to the original size.

    Seams are removed until no *rmask* pixel exceeds MASK_THRESHOLD. Then
    seams_insertion adds back as many seams as were removed. With
    horizontal_removal=True the work is done on a rotated image, so rows
    rather than columns are removed and re-added.

    Args:
        im: image array; cast to float64.
        rmask: removal mask; cast to float64.
        mask: optional protective mask; cast to float64.
        vis: show each seam via visualize().
        horizontal_removal: carve horizontal seams instead of vertical ones.

    Returns:
        The float64 result, in the original orientation.
    """
    output = im.astype(np.float64)
    rmask = rmask.astype(np.float64)
    if mask is not None:
        mask = mask.astype(np.float64)
    h, w = output.shape[:2]
    target = h if horizontal_removal else w

    if horizontal_removal:
        output = rotate_image(output, True)
        rmask = rotate_image(rmask, True)
        if mask is not None:
            mask = rotate_image(mask, True)

    while np.any(rmask > MASK_THRESHOLD):
        _, keep = get_minimum_seam(output, mask, rmask)
        if vis:
            visualize(output, keep, rotate=horizontal_removal)
        output = remove_seam(output, keep)
        rmask = remove_seam_grayscale(rmask, keep)
        if mask is not None:
            mask = remove_seam_grayscale(mask, keep)

    output, mask = seams_insertion(
        output, target - output.shape[1], mask, vis, rot=horizontal_removal
    )
    if horizontal_removal:
        output = rotate_image(output, False)
    return output
def s_image(im, mask, vs, hs, mode="resize"):
    """Seam-carve an image, either resizing it or removing a masked region.

    Args:
        im: image converted with cv2.COLOR_RGBA2RGB, so a 4-channel array is
            expected.
        mask: array whose channel 3 is inverted (255 - value). Low values,
            i.e. transparent pixels, become "marked". In "resize" mode marked
            pixels are protected from carving; in "remove" mode they are
            carved away.
        vs: column delta passed to seam_carve as dx ("resize" mode).
        hs: row delta passed to seam_carve as dy ("resize" mode).
        mode: "resize" or "remove".

    Returns:
        The float64 result of seam_carve / object_removal. When SHOULD_DOWNSIZE
        is set and the input is wider than DOWNSIZE_WIDTH, the work and the
        result are at the downsized scale.

    Raises:
        ValueError: if *mode* is not "resize" or "remove". Previously this
            surfaced only as an UnboundLocalError after all the preprocessing.
    """
    if mode not in ("resize", "remove"):
        raise ValueError(f"Unsupported mode {mode!r}; expected 'resize' or 'remove'")

    im = cv2.cvtColor(im, cv2.COLOR_RGBA2RGB)
    mask = 255 - mask[:, :, 3]
    w = im.shape[1]
    if SHOULD_DOWNSIZE and w > DOWNSIZE_WIDTH:
        im = resize(im, width=DOWNSIZE_WIDTH)
        mask = resize(mask, width=DOWNSIZE_WIDTH)

    if mode == "resize":
        dy = hs
        dx = vs
        assert dy is not None and dx is not None
        return seam_carve(im, dy, dx, mask, False)
    # "remove": carve away marked rows, then regrow to the original height.
    return object_removal(im, mask, None, False, True)
##### Inpainting helper code
def run(image, mask):
    """Inpaint the pixels of *image* where *mask* is non-zero, using the module-level model.

    Args:
        image: array shaped [C, H, W].
        mask: array shaped [1, H, W]; binarised to 0/1 (non-zero -> 1).

    Returns:
        uint8 array shaped [H, W, C], cropped back to the unpadded size, with
        channels swapped by cv2.COLOR_BGR2RGB (the original docstring
        describes this result as a BGR image).

    Both inputs are padded to a multiple of 8 (pad_img_to_modulo) before
    inference. The wall-clock inference time is printed.
    """
    height, width = image.shape[1:]
    padded_image = pad_img_to_modulo(image, mod=8)
    padded_mask = (pad_img_to_modulo(mask, mod=8) > 0) * 1
    image_t = torch.from_numpy(padded_image).unsqueeze(0).to(device)
    mask_t = torch.from_numpy(padded_mask).unsqueeze(0).to(device)

    started = time.time()
    with torch.no_grad():
        prediction = model(image_t, mask_t)
    print(f"process time: {(time.time() - started)*1000}ms")

    result = prediction[0].permute(1, 2, 0).detach().cpu().numpy()
    result = result[:height, :width, :]
    result = np.clip(result * 255, 0, 255).astype("uint8")
    return cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
def get_args_parser():
    """Parse the command-line options from sys.argv.

    Despite the name, this returns the parsed argparse.Namespace, not the
    parser. Options: --port (int, default 8080), --device (str, default
    "cuda"), --debug (flag).
    """
    parser = argparse.ArgumentParser()
    options = (
        ("--port", {"default": 8080, "type": int}),
        ("--device", {"default": "cuda", "type": str}),
        ("--debug", {"action": "store_true"}),
    )
    for flag, settings in options:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
def _to_uint8_image(img):
    """Return *img* as a uint8 array.

    uint8 input is returned unchanged. Any other dtype is assumed to hold
    values in [0, 1], then scaled by 255 and clipped.
    NOTE(review): confirm callers never pass float images in [0, 255].
    """
    if img.dtype == np.uint8:
        return img
    return np.clip(img * 255, 0, 255).astype(np.uint8)


def _extract_removal_mask(mask, invert_mask):
    """Build a single-channel uint8 mask in which 255 marks pixels to remove.

    The alpha channel is used when *mask* has 4 channels and its alpha is
    meaningfully transparent (mean < 240); then transparent -> 255.
    Otherwise (mostly opaque alpha, or a 3-channel mask) the RGB channels are
    converted to gray and thresholded at 128; then white -> 255.
    invert_mask=False flips the result in either case.
    """
    has_alpha = mask.ndim == 3 and mask.shape[2] == 4
    if has_alpha and mask[:, :, 3].mean() < 240:
        alpha_channel = mask[:, :, 3]
        # alpha=0 (transparent) -> 255 (remove); alpha=255 (opaque) -> 0 (keep)
        removal = 255 - alpha_channel
        transparent_count = int((alpha_channel < 128).sum())
        print(f"Using alpha channel: {transparent_count} transparent pixels → white (to remove)")
        if not invert_mask:
            removal = 255 - removal
            print(f"Applied invert_mask=False: inverted alpha-based mask")
        return removal

    # RGB masks: white (255) = remove, black (0) = keep.
    gray = cv2.cvtColor(mask[:, :, :3], cv2.COLOR_RGB2GRAY)
    removal = (gray > 128).astype(np.uint8) * 255
    white_count = int((removal > 128).sum())
    print(f"Using RGB channels: {white_count} white pixels (to remove)")
    if not invert_mask:
        # invert_mask=False means black=remove.
        removal = 255 - removal
        print(f"Applied invert_mask=False: inverted RGB mask (now black=remove)")
    return removal


def process_inpaint(image, mask, invert_mask=True):
    """
    Process inpainting - handles both alpha-based masks and RGB-based masks.
    Preserves original image quality and dimensions.
    Reference: https://huggingface.co/spaces/aryadytm/remove-photo-object

    Args:
        image: image converted with cv2.COLOR_RGBA2RGB, so a 4-channel array is
            expected. Inputs larger than 3000px on their longest side are
            downscaled for inference and the result is scaled back.
        mask: 4-channel (alpha- or RGB-based) or 3-channel (RGB-based) mask.
            See _extract_removal_mask.
        invert_mask: False flips which mask pixels are removed.

    Returns:
        uint8 array at the input's original height and width. It is run()'s
        output passed through cv2.COLOR_BGR2RGB, i.e. RGB given that run()
        describes its output as BGR. If the mask marks fewer than 10 pixels,
        the input (after RGBA->RGB) is returned unchanged, in the same
        channel order.
    """
    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # Full-resolution RGB input, kept for the empty-mask early return.
    original_rgb = image
    original_shape = image.shape  # (H, W, C)
    interpolation = cv2.INTER_CUBIC
    # Only downscale extremely large images (over 3000px) to preserve quality.
    max_dimension = max(image.shape[:2])
    if max_dimension > 3000:
        size_limit = 3000
        print(f"Very large image detected ({max_dimension}px), resizing to {size_limit}px for processing")
    else:
        size_limit = max_dimension  # Keep original size to preserve quality
        print(f"Preserving original image size: {max_dimension}px (no resize)")
    print(f"Origin image shape: {original_shape}")
    was_resized = size_limit < max_dimension
    if was_resized:
        image_resized = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
        print(f"Resized image shape: {image_resized.shape}")
    else:
        image_resized = image
        print(f"Image not resized: {image_resized.shape}")
    image = norm_img(image_resized)

    mask = _extract_removal_mask(mask, invert_mask)

    # Force the mask to the (possibly resized) image size; nearest-neighbour
    # keeps it binary.
    target_h, target_w = image_resized.shape[:2]
    if mask.shape[:2] != (target_h, target_w):
        mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)

    mask_nonzero = int((mask > 128).sum())
    mask_total = mask.shape[0] * mask.shape[1]
    print(f"Final mask before normalization: {mask_nonzero}/{mask_total} pixels marked for removal ({100*mask_nonzero/mask_total:.2f}%)")
    if mask_nonzero < 10:
        print("ERROR: Mask is empty or almost empty! Returning original image.")
        # Return the untouched full-resolution input in the same channel order
        # as the normal path below. Scaling uint8 pixels by 255 would overflow.
        return _to_uint8_image(original_rgb)

    print(f"Mask verification: {mask_nonzero} pixels will be removed, shape: {mask.shape}")
    mask = norm_img(mask)
    mask_normalized_ones = int((mask > 0.5).sum())
    print(f"After normalization: {mask_normalized_ones} pixels marked for removal (value > 0.5)")

    print("Running LaMa model for inpainting...")
    res_np_img = run(image, mask)
    print(f"Inpainting complete. Output shape: {res_np_img.shape}")

    # Diagnostic only: count pixels whose absolute change, summed over the
    # channels, exceeds 10. image_resized holds raw pixel values, so it is
    # converted to uint8 rather than multiplied by 255.
    original_bgr = cv2.cvtColor(_to_uint8_image(image_resized), cv2.COLOR_RGB2BGR)
    diff = np.abs(res_np_img.astype(np.float32) - original_bgr.astype(np.float32))
    diff_pixels = int((diff.sum(axis=2) > 10).sum())
    print(f"Output verification: {diff_pixels} pixels differ from input (should be > 0 if inpainting worked)")

    # Scale back to the original dimensions if we downscaled (LANCZOS4 for quality).
    if was_resized:
        res_np_img = cv2.resize(res_np_img, (original_shape[1], original_shape[0]),
                                interpolation=cv2.INTER_LANCZOS4)
        print(f"Resized output back to original size: {res_np_img.shape}")
    return cv2.cvtColor(res_np_img, cv2.COLOR_BGR2RGB)