import logging
import time
from typing import Any, Callable, Dict, Optional, Tuple

import cv2
import numpy as np
import spaces
from PIL import Image

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class GPUHandlers:
    """
    Handles all GPU-intensive generation operations.

    This class encapsulates the execution logic for both background generation
    and inpainting operations with proper @spaces.GPU decorator for
    HuggingFace Spaces deployment.

    Supports dual-mode inpainting:
    - Pure Inpainting (use_controlnet=False): For object replacement/removal
    - ControlNet Inpainting (use_controlnet=True): For clothing/color change
    """

    def __init__(
        self,
        core: Any,
        inpainting_template_manager: Any
    ):
        """
        Initialize the GPU handlers.

        Parameters
        ----------
        core : SceneWeaverCore
            Main engine instance
        inpainting_template_manager : InpaintingTemplateManager
            Template manager for inpainting
        """
        self.core = core
        self.inpainting_template_manager = inpainting_template_manager
        logger.info("GPUHandlers initialized")

    @spaces.GPU(duration=240)
    def background_generate(
        self,
        image: Optional[Image.Image],
        prompt: str,
        negative_prompt: str,
        composition_mode: str,
        focus_mode: str,
        num_steps: int,
        guidance_scale: float,
        progress_callback: Optional[Callable[[str, int], None]] = None
    ) -> Dict[str, Any]:
        """
        Handle background generation request with GPU access.

        Parameters
        ----------
        image : PIL.Image, optional
            Input image
        prompt : str
            Generation prompt
        negative_prompt : str
            Negative prompt
        composition_mode : str
            Composition mode (center, left_half, etc.)
        focus_mode : str
            Focus mode (person, scene)
        num_steps : int
            Number of inference steps
        guidance_scale : float
            Guidance scale
        progress_callback : callable, optional
            Progress update function(message, percentage)

        Returns
        -------
        dict
            Result dictionary with success status and images
        """
        if image is None:
            return {"success": False, "error": "Please upload an image first"}
        if not prompt.strip():
            return {"success": False, "error": "Please enter a prompt"}

        try:
            logger.info(f"Starting background generation: {prompt[:50]}...")
            start_time = time.time()

            # Initialize if needed
            if not self.core.is_initialized:
                if progress_callback:
                    progress_callback("Loading AI models...", 5)
                self.core.load_models(progress_callback=progress_callback)

            # Generate and combine
            if progress_callback:
                progress_callback("Generating background...", 20)

            result = self.core.generate_and_combine(
                original_image=image,
                prompt=prompt,
                combination_mode=composition_mode,
                focus_mode=focus_mode,
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                progress_callback=progress_callback
            )

            elapsed = time.time() - start_time
            logger.info(f"Background generation complete in {elapsed:.1f}s")
            return result

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Background generation error: {error_msg}")
            return {"success": False, "error": error_msg}

    @spaces.GPU(duration=420)
    def inpainting_generate(
        self,
        image: Optional[Image.Image],
        mask: Optional[Image.Image],
        prompt: str,
        template_key: Optional[str],
        model_key: str,
        conditioning_type: str,
        conditioning_scale: float,
        feather_radius: int,
        guidance_scale: float,
        num_steps: int,
        seed: int = -1,
        progress_callback: Optional[Callable[[str, int], None]] = None
    ) -> Tuple[Optional[Image.Image], Optional[Image.Image], str, int]:
        """
        Handle inpainting request with GPU access.
        Supports dual-mode operation based on template:
        - Pure Inpainting: For object_replacement, removal
        - ControlNet: For clothing_change, change_color

        Parameters
        ----------
        image : PIL.Image
            Original image to inpaint
        mask : PIL.Image
            Inpainting mask (white = area to regenerate)
        prompt : str
            Inpainting prompt
        template_key : str, optional
            Template key if using a template
        model_key : str
            Model key (juggernaut_xl, realvis_xl, sdxl_base, animagine_xl)
        conditioning_type : str
            ControlNet conditioning type (canny/depth) - only for ControlNet mode
        conditioning_scale : float
            ControlNet conditioning scale
        feather_radius : int
            Mask feather radius
        guidance_scale : float
            Generation guidance scale
        num_steps : int
            Number of inference steps
        seed : int
            Random seed (-1 for random)
        progress_callback : callable, optional
            Progress update function

        Returns
        -------
        tuple
            (result_image, control_image, status_message, used_seed)
        """
        if image is None:
            return None, None, "Please upload an image first", -1
        if mask is None:
            return None, None, "Please draw a mask on the image", -1

        try:
            logger.info(f"Starting inpainting: prompt='{prompt[:30]}...', template={template_key}")
            start_time = time.time()

            # Get template parameters
            built_prompt = prompt
            negative_prompt = ""
            template_params = {}
            use_controlnet = True  # Default to ControlNet mode

            if template_key:
                template = self.inpainting_template_manager.get_template(template_key)
                if template:
                    # For removal template, use template prompt directly if user prompt is empty
                    if template_key == "removal" and not prompt.strip():
                        built_prompt = template.prompt_template
                    else:
                        built_prompt = self.inpainting_template_manager.build_prompt(template_key, prompt)
                    negative_prompt = self.inpainting_template_manager.get_negative_prompt(template_key)
                    template_params = self.inpainting_template_manager.get_parameters_for_template(template_key)
                    use_controlnet = template_params.get("use_controlnet", True)
                    logger.info(f"Template: {template_key}, use_controlnet={use_controlnet}")

            # Build final parameters
            final_params = {
                # Pipeline mode
                "use_controlnet": use_controlnet,
                "mask_dilation": template_params.get("mask_dilation", 0),
                # ControlNet parameters (only used if use_controlnet=True)
                "conditioning_type": template_params.get("preferred_conditioning", conditioning_type),
                "controlnet_conditioning_scale": template_params.get("controlnet_conditioning_scale", conditioning_scale),
                "preserve_structure_in_mask": template_params.get("preserve_structure_in_mask", False),
                "edge_guidance_mode": template_params.get("edge_guidance_mode", "boundary"),
                # Generation parameters
                "feather_radius": template_params.get("feather_radius", feather_radius),
                "guidance_scale": template_params.get("guidance_scale", guidance_scale),
                "num_inference_steps": template_params.get("num_inference_steps", num_steps),
                "strength": template_params.get("strength", 0.99),
                "negative_prompt": negative_prompt,
                "seed": seed,
            }

            # Execute inpainting through core
            result = self.core.execute_inpainting(
                image=image,
                mask=mask,
                prompt=built_prompt,
                model_key=model_key,
                progress_callback=progress_callback,
                **final_params
            )

            elapsed = time.time() - start_time

            if result.get('success'):
                mode_str = "Pure Inpainting" if not use_controlnet else "ControlNet"
                # Get the actual seed used from metadata
                used_seed = result.get('metadata', {}).get('seed', seed)
                status = f"Complete ({mode_str}) in {elapsed:.1f}s | Seed: {used_seed}"
                return (
                    result.get('combined_image'),
                    result.get('control_image'),
                    status,
                    used_seed
                )
            else:
                error_msg = result.get('error', 'Unknown error')
return None, None, f"Error: {error_msg}", -1 except Exception as e: error_msg = str(e) logger.error(f"Inpainting handler error: {e}") return None, None, f"Error: {error_msg}", -1 def extract_mask_from_editor(mask_editor: Dict[str, Any]) -> Optional[Image.Image]: """ Extract mask from Gradio ImageEditor component. Parameters ---------- mask_editor : dict ImageEditor output with 'background' and 'layers' Returns ------- PIL.Image or None Extracted mask image (L mode) """ if mask_editor is None: return None try: layers = mask_editor.get("layers", []) if not layers: return None mask_layer = layers[0] if mask_layer is None: return None # Convert to numpy array if isinstance(mask_layer, Image.Image): mask_array = np.array(mask_layer) else: mask_array = np.array(Image.open(mask_layer)) # Handle different formats if len(mask_array.shape) == 3: if mask_array.shape[2] == 4: # RGBA - use alpha channel combined with RGB alpha = mask_array[:, :, 3] gray = cv2.cvtColor(mask_array[:, :, :3], cv2.COLOR_RGB2GRAY) mask_gray = np.maximum(gray, alpha) elif mask_array.shape[2] == 3: # RGB - convert to grayscale mask_gray = cv2.cvtColor(mask_array, cv2.COLOR_RGB2GRAY) else: mask_gray = mask_array[:, :, 0] else: mask_gray = mask_array return Image.fromarray(mask_gray.astype(np.uint8), mode='L') except Exception as e: logger.error(f"Failed to extract mask from editor: {e}") return None