"""Gradio demo for GCC: Generative Color Constancy via Diffusing a Color Checker.

The script loads a fine-tuned one-step inpainting pipeline, paints a reference
color checker into the centre of the uploaded image, lets the pipeline
re-render it, reads the illuminant from a gray patch of the result and applies
the corresponding white balance gains to the original image.
"""

import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers import DDIMScheduler, UNet2DConditionModel
import spaces
from pipelines import OneStepLaplacianInpaintPipeline


class GCCDemo:
    """GCC Demo for Color Constancy using Diffusion Models"""

    # Side length (pixels) of the square image handed to the pipeline.
    MODEL_SIZE = 512

    def __init__(self, model_path="your-username/gcc-color-checker-diffusion", checker_path="color_checker.jpg"):
        """Load the reference checker image and the diffusion pipeline.

        Args:
            model_path: Hub id or local path holding the fine-tuned ``unet`` subfolder.
            checker_path: Path of the reference color checker image.

        Raises:
            FileNotFoundError: If the checker image cannot be read.
        """
        self.model_path = model_path
        self.checker_path = checker_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.color_checker = self.load_color_checker()
        self.load_model()

    def load_color_checker(self, width=180, height=135):
        """Load the color checker image from ``self.checker_path`` and resize it.

        Args:
            width: Target width in pixels.
            height: Target height in pixels.

        Returns:
            The resized checker as read by ``cv2.imread`` (BGR channel order).

        Raises:
            FileNotFoundError: If the file is missing or cannot be decoded.
        """
        color_checker = cv2.imread(self.checker_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        if color_checker is None:
            raise FileNotFoundError(
                f"Could not read color checker image: {self.checker_path}"
            )
        return cv2.resize(color_checker, (width, height))

    # NOTE(review): on ZeroGPU hardware a @spaces.GPU function may run in a
    # separate worker -- confirm that the self.pipeline assignment below is
    # visible to later process_image() calls in that environment.
    @spaces.GPU
    def load_model(self):
        """Load the fine-tuned GCC model into ``self.pipeline``.

        Failures are reported on stdout and leave ``self.pipeline`` as ``None``;
        ``process_image`` checks for that and reports "Model not loaded".
        """
        try:
            base_model = "stabilityai/stable-diffusion-2-inpainting"

            scheduler = DDIMScheduler.from_pretrained(
                base_model,
                subfolder="scheduler",
                timestep_spacing="trailing",
                prediction_type="v_prediction",
            )
            # The pipeline is always run with a single denoising step.
            scheduler.set_timesteps(1)

            # Half precision only when running on the GPU.
            torch_dtype = torch.float32 if self.device == "cpu" else torch.float16

            unet = UNet2DConditionModel.from_pretrained(
                self.model_path,
                subfolder="unet",
                torch_dtype=torch_dtype,
            )

            self.pipeline = OneStepLaplacianInpaintPipeline.from_pretrained(
                base_model,
                torch_dtype=torch_dtype,
                scheduler=scheduler,
                unet=unet,
            )
            self.pipeline.to(self.device)
            print("✅ Model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading model: {e}")

    def gamma_correction(self, image, gamma=2.2):
        """Apply gamma correction to convert RAW to sRGB.

        Args:
            image: Array with values on a 0-255 scale.
            gamma: Encoding gamma; the exponent applied is ``1 / gamma``.

        Returns:
            ``uint8`` array of the same shape.
        """
        img_array = image.astype(np.float32) / 255.0
        img_corrected = np.power(img_array, 1.0 / gamma)
        return (img_corrected * 255).astype(np.uint8)

    def inverse_gamma_correction(self, image, gamma=2.2):
        """Convert sRGB back to linear space.

        Args:
            image: Array with values on a 0-255 scale.
            gamma: Exponent applied after scaling to 0-1.

        Returns:
            ``float32`` array of the same shape.
        """
        img_array = image.astype(np.float32) / 255.0
        return np.power(img_array, gamma)

    def _white_balance_gains(self, illuminant):
        """Return ``[R, G, B]`` gains that use the green channel as reference.

        Raises:
            ValueError: If any of the first three illuminant components is
                zero, negative or not finite (the gains would be inf/NaN).
        """
        rgb = np.asarray(illuminant)[:3]
        if not (np.all(np.isfinite(rgb)) and np.all(rgb > 0)):
            raise ValueError(
                f"Invalid illuminant {rgb.tolist()}: all channels must be finite and positive"
            )
        return np.array([rgb[1] / rgb[0], 1.0, rgb[1] / rgb[2]])

    def apply_white_balance(self, image, illuminant, gamma=2.2):
        """Apply white balance correction using the estimated illuminant.

        Args:
            image: BGR image with 3 or 4 channels (a fourth channel is dropped
                by the BGR->RGB conversion). Bit depth is preserved.
            illuminant: Sequence ``(R, G, B)`` of the estimated illuminant.
            gamma: Gamma used to encode the balanced result.

        Returns:
            3-channel BGR image with the same dtype as ``image``.

        Raises:
            ValueError: If ``image`` is not a 3- or 4-channel image, or if the
                illuminant has a non-positive or non-finite channel.
        """
        if image.ndim != 3 or image.shape[2] not in (3, 4):
            raise ValueError(
                f"apply_white_balance expects a 3- or 4-channel image, got shape {image.shape}"
            )

        gains = self._white_balance_gains(illuminant)

        # Value that maps to 1.0, chosen from the image's dtype.
        original_dtype = image.dtype
        if original_dtype == np.uint8:
            max_val = 255.0
        elif original_dtype == np.uint16:
            max_val = 65535.0
        elif original_dtype == np.float32 or original_dtype == np.float64:
            # Float data already within [0, 1] is left unscaled.
            peak = float(image.max())
            max_val = 1.0 if peak <= 1.0 else peak
        elif np.issubdtype(original_dtype, np.integer):
            max_val = float(np.iinfo(original_dtype).max)
        else:
            # Fall back to 1.0 for an all-zero image to avoid dividing by zero.
            max_val = float(image.max()) or 1.0

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Scale to [0, 1] in float64, then apply the per-channel gains via
        # broadcasting over the last (RGB) axis.
        corrected = image_rgb.astype(np.float64) / max_val
        corrected *= gains
        corrected = np.clip(corrected, 0, 1)

        # Encode from linear to gamma space and restore the original depth.
        gamma_corrected = np.power(corrected, 1.0 / gamma)
        result_rgb = (gamma_corrected * max_val).astype(original_dtype)
        return cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR)

    def _create_swatch_masks(self, width, height, swatches_h, swatches_v, samples):
        """Build one ``[row0, row1, col0, col1]`` sampling window per swatch.

        Windows are centred on a regular ``swatches_v`` x ``swatches_h`` grid and
        listed row by row. Returns an ``int32`` array of shape
        ``(swatches_h * swatches_v, 4)``.
        """
        samples_half = max(samples / 2, 1)
        offset_h = width / swatches_h / 2
        offset_v = height / swatches_v / 2

        masks = []
        for j in np.linspace(offset_v, height - offset_v, swatches_v):
            for i in np.linspace(offset_h, width - offset_h, swatches_h):
                masks.append(np.array([
                    j - samples_half,
                    j + samples_half,
                    i - samples_half,
                    i + samples_half,
                ], dtype=np.int32))
        return np.array(masks, dtype=np.int32)

    def _extract_swatch_colors(self, image, masks):
        """Return the mean color inside each mask window as a ``float32`` array."""
        return np.array([
            np.mean(image[mask[0]:mask[1], mask[2]:mask[3], ...], axis=(0, 1))
            for mask in masks
        ], dtype=np.float32)

    def estimate_illuminant(self, image: np.ndarray, checker_pos: tuple) -> np.ndarray:
        """Estimate illuminant color from ColorChecker Classic chart.

        Args:
            image: Image containing the rendered checker.
            checker_pos: ``(x1, y1, x2, y2)`` bounding box of the checker.

        Returns:
            Mean color of the 20th swatch, in the channel order of ``image``.
        """
        x1, y1, x2, y2 = checker_pos

        # Corners of the checker: top-left, top-right, bottom-right, bottom-left.
        coords_pixel = np.array([
            [x1, y1],
            [x2, y1],
            [x2, y2],
            [x1, y2],
        ], dtype=np.float32)

        working_width = int(x2 - x1)
        working_height = int(y2 - y1)
        samples = int(working_width / 15)  # adjustable sample size

        # Destination rectangle (top-down view of the checker only).
        rectangle = np.array([
            [0, 0],
            [working_width, 0],
            [working_width, working_height],
            [0, working_height],
        ], dtype=np.float32)

        M = cv2.getPerspectiveTransform(coords_pixel, rectangle)
        warped = cv2.warpPerspective(
            image, M, (working_width, working_height), flags=cv2.INTER_CUBIC
        )

        # 6 columns x 4 rows of swatches.
        masks = self._create_swatch_masks(working_width, working_height, 6, 4, samples)
        colours = self._extract_swatch_colors(warped, masks)

        # Use 20th patch (index 19) which corresponds to a gray patch
        return colours[19]

    def get_checker_position(self, image_shape, checker_shape):
        """Return ``(x1, y1, x2, y2)`` that centres the checker in the image.

        Both arguments are ``(height, width, ...)`` shape tuples.
        """
        h, w = image_shape[:2]
        ch, cw = checker_shape[:2]

        x1 = (w - cw) // 2
        y1 = (h - ch) // 2
        return (x1, y1, x1 + cw, y1 + ch)

    def _ensure_bgr(self, image):
        """Return ``image`` as 3-channel BGR (grayscale expanded, alpha dropped).

        Raises:
            ValueError: If the channel layout is not 1-, 3- or 4-channel.
        """
        if image.ndim == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        if image.ndim == 3 and image.shape[2] == 3:
            return image
        raise ValueError(f"Unsupported image shape: {image.shape}")

    def _to_uint8(self, image):
        """Scale ``image`` to 8-bit: uint16 by 65535, other dtypes by their maximum."""
        if image.dtype == np.uint8:
            return image
        if image.dtype == np.uint16:
            scale = 65535.0
        else:
            scale = float(np.max(image))
            # An all-black (or non-positive) image would otherwise divide by zero.
            if scale <= 0:
                scale = 1.0
        scaled = image.astype(np.float32) / scale * 255.0
        return np.clip(scaled, 0, 255).astype(np.uint8)

    @spaces.GPU
    def process_image(self, input_image):
        """Main processing function.

        Args:
            input_image: File path of the uploaded image (read with OpenCV so
                the bit depth is preserved), or a PIL image / RGB array.

        Returns:
            Tuple ``(generated_image, white_balanced_image, info_text)``. On
            failure the two images are ``None`` and ``info_text`` holds the
            error message.
        """
        if input_image is None:
            return None, None, "Please upload an image"

        if self.pipeline is None:
            return None, None, "Model not loaded"

        try:
            if isinstance(input_image, str):
                # IMREAD_UNCHANGED keeps the file's bit depth and channels.
                image_bgr = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
                if image_bgr is None:
                    return None, None, "Failed to read image file"
            else:
                # Fallback for in-memory input (PIL images are 8-bit RGB here).
                if isinstance(input_image, Image.Image):
                    image_np = np.array(input_image.convert("RGB"))
                else:
                    image_np = np.asarray(input_image)
                image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

            # Grayscale and BGRA files cannot receive the 3-channel checker
            # below, so normalise the layout once up front.
            raw_original = self._ensure_bgr(image_bgr)

            # The model works on an 8-bit, gamma-encoded square image.
            size = self.MODEL_SIZE
            image_512 = cv2.resize(self._to_uint8(raw_original), (size, size))
            image_with_checker = self.gamma_correction(image_512)

            checker_pos = self.get_checker_position((size, size), self.color_checker.shape)
            x1, y1, x2, y2 = checker_pos

            # Mask marks the region to inpaint; the reference checker is pasted there.
            mask = np.zeros((size, size), dtype=np.uint8)
            mask[y1:y2, x1:x2] = 255
            image_with_checker[y1:y2, x1:x2] = self.color_checker

            image_pil = Image.fromarray(cv2.cvtColor(image_with_checker, cv2.COLOR_BGR2RGB))
            mask_pil = Image.fromarray(mask)

            prompt = "a scene with a color checker that accurately reflects the ambient lighting of the scene."
            output_images = self.pipeline(
                prompt,
                image=image_pil,
                mask_image=mask_pil,
                generator=torch.Generator(device="cpu").manual_seed(42),
                num_inference_steps=1,
                guidance_scale=0,
            ).images
            output_image = output_images[0]

            # Read the illuminant from the generated checker in linear space.
            output_linear = self.inverse_gamma_correction(np.array(output_image))
            illuminant = self.estimate_illuminant(output_linear, checker_pos)
            gains = self._white_balance_gains(illuminant)

            # White balance the full-resolution original (bit depth preserved),
            # then reduce to 8-bit RGB for display.
            wb_image = self.apply_white_balance(raw_original, illuminant)
            wb_image_rgb = cv2.cvtColor(wb_image, cv2.COLOR_BGR2RGB)
            wb_image_pil = Image.fromarray(self._to_uint8(wb_image_rgb))

            info = (
                "🔍 Estimated Illuminant (RGB):\n"
                f"R: {illuminant[0]:.6f}\n"
                f"G: {illuminant[1]:.6f}\n"
                f"B: {illuminant[2]:.6f}\n"
                "\n"
                "⚖️ White Balance Gains:\n"
                f"R: {gains[0]:.4f}\n"
                "G: 1.0000\n"
                f"B: {gains[2]:.4f}"
            )

            return output_image, wb_image_pil, info

        except Exception as e:
            # UI boundary: surface the failure in the info box instead of crashing.
            return None, None, f"Processing error: {str(e)}"


# Initialize demo
demo_app = GCCDemo(
    model_path="StevenChangWei/gcc_train_on_nus8",
    checker_path="color_chart.jpg",  # Path to your color checker image
)


def process_wrapper(input_image):
    """Module-level callback that forwards to ``demo_app.process_image``."""
    return demo_app.process_image(input_image)


# Create Gradio interface
with gr.Blocks(title="GCC: Generative Color Constancy", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 📸 GCC: Generative Color Constancy via Diffusing a Color Checker

    🌟 [GitHub Repository](https://github.com/chenwei891213/GCC-official) | 🚀 [Project Page](https://chenwei891213.github.io/GCC/) | 📄 [Paper](https://arxiv.org/abs/2502.17435)

    Upload a RAW image to automatically correct its white balance using our diffusion-based color constancy method.
    GCC generates a color checker that accurately reflects the scene's ambient lighting, then uses it to estimate the illuminant and apply precise white balance correction.

    **Getting Started:**
    1. **Upload Your Image:** Use the image upload box on the left to provide your RAW image (supports various bit depths).
    2. **Process:** Click the "🚀 Process Image" button to start the color constancy correction.
    3. **View Results:** The generated color checker and white balanced result will appear below.
    4. **Download:** Right-click on any result image to save it as PNG to your device.
    5. **Technical Details (Optional):** After processing, you can view detailed processing information in the expandable section below.
    """)

    # First Row: Input Image and Generated Color Checker
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input Image")
            input_image = gr.Image(
                type="filepath",
                label="Upload RAW image",
                height=400
            )

            # Process Button
            process_btn = gr.Button(
                "🚀 Process Image",
                variant="primary",
                size="lg"
            )

        with gr.Column(scale=1):
            gr.Markdown("### 🎯 Generated Color Checker (512×512)")
            output_checker = gr.Image(
                label="Scene-harmonized color checker",
                height=400,
                format="png"
            )

    # Second Row: White Balanced Result (Full Width)
    with gr.Row():
        gr.Markdown("### ⚖️ White Balanced Result")

    with gr.Row():
        output_corrected = gr.Image(
            label="Final white balanced image (saves as PNG)",
            height=500,
            format="png"
        )

    # Processing Information (Collapsible)
    with gr.Accordion("📋 Processing Information", open=False):
        info_output = gr.Textbox(
            label="Processing Results",
            lines=8,
            interactive=False
        )

        gr.Markdown("""
        **Processing Steps:**
        1. **Preprocessing:** Resize input to 512×512, apply gamma correction (γ=2.2)
        2. **Color Checker Generation:** Generate scene-harmonized color checker at image center using diffusion model
        3. **Illuminant Estimation:** Extract illuminant from gray patch (patch 20)
        4. **White Balance Correction:** Apply RGB gains in linear space, then apply gamma correction (γ=2.2) for sRGB output
        """)

    # Event handlers
    process_btn.click(
        fn=process_wrapper,
        inputs=[input_image],
        outputs=[output_checker, output_corrected, info_output]
    )

    gr.Markdown("""
    ### 📄 Citation
    ```bibtex
    @InProceedings{Chang_2025_CVPR,
        author = {Chang, Chen-Wei and Fan, Cheng-De and Chang, Chia-Che and Lo, Yi-Chen and Tseng, Yu-Chee and Huang, Jiun-Long and Liu, Yu-Lun},
        title = {GCC: Generative Color Constancy via Diffusing a Color Checker},
        booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
        month = {June},
        year = {2025},
        pages = {10868-10878}
    }
    ```
    """)

if __name__ == "__main__":
    demo.launch(share=True)