# affecto-inference / preprocessor.py
# gauravvjhaa's picture
# slight e
# ef2c483
import os
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from insightface.app import FaceAnalysis
from huggingface_hub import hf_hub_download
from utils.bisenet import BiSeNet
class FacePreprocessor:
    """Prepare a portrait for inference: detect, crop and isolate background.

    The pipeline has two stages:

    1. ``detect_and_crop_face`` finds the largest face with InsightFace and
       returns a 512x512 crop around it.
    2. ``extract_background`` runs BiSeNet face parsing on that crop and
       blacks out every pixel that is not labelled as class 0.

    ``preprocess`` chains both stages and falls back to a plain resize plus
    Gaussian blur if anything goes wrong.

    NOTE(review): the emoji prefixes in the log/exception strings look
    mis-encoded in this copy of the file. They are runtime output, so they
    are kept exactly as found -- confirm against the original encoding.
    """

    def __init__(self, device='cuda'):
        """Load the face detector, the parsing network and the transforms.

        Args:
            device: Requested torch device. It is only honoured when CUDA is
                available; otherwise ``'cpu'`` is used.

        Raises:
            RuntimeError: If neither the ``antelopev2`` nor the ``buffalo_l``
                InsightFace model pack could be loaded.
        """
        self.device = device if torch.cuda.is_available() else 'cpu'
        print(f"πŸ”§ Initializing Face Preprocessor on: {self.device}")

        # Initialize InsightFace, preferring antelopev2 and falling back to
        # buffalo_l. Only AssertionError triggers the fallback (presumably
        # how insightface reports an unusable model pack -- TODO confirm);
        # any other error from the first attempt propagates unchanged.
        print(" - Loading InsightFace...")
        try:
            self.app = self._load_insightface('antelopev2')
            print(" βœ… InsightFace loaded successfully")
        except AssertionError:
            print(" ⚠️ InsightFace antelopev2 failed, trying buffalo_l...")
            try:
                self.app = self._load_insightface('buffalo_l')
                print(" βœ… InsightFace buffalo_l loaded successfully")
            except Exception as e2:
                print(f" ❌ Both InsightFace models failed: {e2}")
                raise RuntimeError("Could not load InsightFace models") from e2

        # Initialize BiSeNet for segmentation.
        print(" - Loading BiSeNet for background extraction...")
        self.bisenet = BiSeNet(n_classes=19).to(self.device)

        # Fetch the BiSeNet weights from the affecto-faceparsing repo. This
        # is deliberately best-effort: on any failure the network keeps its
        # random initialisation and construction continues.
        print(" - Downloading BiSeNet weights from HuggingFace...")
        try:
            bisenet_path = hf_hub_download(
                repo_id="gauravvjhaa/affecto-faceparsing",
                filename="79999_iter.pth",
                cache_dir="./models",
            )
            print(f" - Loading BiSeNet from: {bisenet_path}")
            # SECURITY NOTE(review): torch.load unpickles the downloaded
            # checkpoint. Consider weights_only=True if the minimum supported
            # torch version allows it.
            state_dict = torch.load(bisenet_path, map_location=self.device)
            self.bisenet.load_state_dict(state_dict)
            print(" βœ… BiSeNet loaded successfully!")
        except Exception as e:
            print(f" ⚠️ BiSeNet download error: {e}")
            print(f" ⚠️ Error details: {e!r}")
            print(" ⚠️ Using random initialization (background extraction will be poor)")
        self.bisenet.eval()

        # Tensor conversion followed by per-channel normalisation.
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        print("βœ… Preprocessor ready!")

    @staticmethod
    def _load_insightface(model_name):
        """Build and prepare a CPU-backed FaceAnalysis app for *model_name*."""
        app = FaceAnalysis(
            name=model_name,
            root='./third_party_files',
            providers=['CPUExecutionProvider'],
        )
        app.prepare(ctx_id=0, det_size=(640, 640))
        return app

    @staticmethod
    def _square_crop_box(bbox, img_w, img_h, scale=1.5):
        """Return the ``(x1, y1, x2, y2)`` crop window around a face box.

        The window is a square of side ``max(w, h) * scale`` centred on
        *bbox*, truncated to integers and then clamped to the image bounds.
        """
        x1, y1, x2, y2 = bbox
        side = max(x2 - x1, y2 - y1) * scale
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        half = side / 2
        left = max(0, int(center_x - half))
        top = max(0, int(center_y - half))
        right = min(img_w, int(center_x + half))
        bottom = min(img_h, int(center_y + half))
        return left, top, right, bottom

    @staticmethod
    def _keep_background(img_np, parsing):
        """Return a copy of *img_np* that is black wherever ``parsing != 0``."""
        background = np.zeros_like(img_np)
        is_background = parsing == 0
        background[is_background] = img_np[is_background]
        return background

    def detect_and_crop_face(self, image):
        """Detect face and crop to 512x512.

        Args:
            image: Input picture. A PIL image in any mode is converted to
                RGB first; other array-like inputs are used as given and are
                expected to be RGB already.

        Returns:
            A 512x512 RGB PIL image around the largest detected face.

        Raises:
            ValueError: If detection raises, no face is found, or the crop
                region is empty.
        """
        print(" πŸ“Έ Detecting face...")
        # The RGB->BGR conversion below needs colour channels, so grayscale
        # or palette PIL images are normalised to RGB beforehand.
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')
        img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        try:
            faces = self.app.get(img_cv)
        except Exception as e:
            print(f" ⚠️ Face detection error: {e}")
            raise ValueError(f"Face detection failed: {e}") from e

        if len(faces) == 0:
            raise ValueError("❌ No face detected!")
        if len(faces) > 1:
            print(" ⚠️ Multiple faces, using largest")

        def bbox_area(candidate):
            box = candidate['bbox']
            return (box[2] - box[0]) * (box[3] - box[1])

        face = max(faces, key=bbox_area)

        h_img, w_img = img_cv.shape[:2]
        left, top, right, bottom = self._square_crop_box(face['bbox'], w_img, h_img)
        img_crop = img_cv[top:bottom, left:right]

        # Check if crop is valid
        if img_crop.size == 0:
            raise ValueError("Invalid crop region")

        # NOTE(review): after clamping, the window may no longer be square,
        # in which case this resize changes the aspect ratio of the face.
        img_crop_512 = cv2.resize(img_crop, (512, 512), interpolation=cv2.INTER_LINEAR)
        img_pil = Image.fromarray(cv2.cvtColor(img_crop_512, cv2.COLOR_BGR2RGB))
        print(" βœ… Face cropped to 512x512")
        return img_pil

    def extract_background(self, image):
        """Extract background using BiSeNet.

        Args:
            image: Picture to segment. A PIL image in any mode is converted
                to RGB first so it matches the 3-channel normalisation.

        Returns:
            A PIL image of the same size in which pixels labelled class 0
            keep their colour and all other pixels are black.
        """
        print(" 🎨 Extracting background...")
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')

        img_tensor = self.to_tensor(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Only the first element of the network output is used.
            out = self.bisenet(img_tensor)[0]
        # Drop the batch axis, then take the best-scoring class per pixel.
        parsing = out.squeeze(0).cpu().numpy().argmax(0)

        # Background is class 0; classes 1-18 (face parts) stay black.
        bg_pil = Image.fromarray(self._keep_background(np.array(image), parsing))
        print(" βœ… Background extracted")
        return bg_pil

    def preprocess(self, image):
        """Full preprocessing pipeline.

        Args:
            image: Input PIL image.

        Returns:
            A ``(face, background)`` pair of PIL images. On success these are
            the 512x512 face crop and its extracted background. If any step
            raises, they are the input resized to 512x512 and a
            Gaussian-blurred copy of that resize.
        """
        try:
            cropped = self.detect_and_crop_face(image)
            bg = self.extract_background(cropped)
            return cropped, bg
        except Exception as e:
            # Deliberate best-effort: never fail the caller, degrade instead.
            print(f"❌ Preprocessing error: {e}")
            print(" ⚠️ Using fallback (simple resize)")
            from PIL import ImageFilter

            img_resized = image.resize((512, 512), Image.LANCZOS)
            bg = img_resized.filter(ImageFilter.GaussianBlur(radius=10))
            return img_resized, bg