import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers import DDIMScheduler, UNet2DConditionModel
import spaces  # Hugging Face Spaces helpers (ZeroGPU support)
from pipelines import OneStepLaplacianInpaintPipeline  # custom pipeline from the GCC repository


class GCCDemo:
    """GCC Demo for Color Constancy using Diffusion Models"""

    def __init__(self, model_path="your-username/gcc-color-checker-diffusion", checker_path="color_checker.jpg"):
        self.model_path = model_path
        self.checker_path = checker_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        self.color_checker = self.load_color_checker()
        self.load_model()

    def load_color_checker(self, width=180, height=135):
        """Load color checker image from file"""
        color_checker = cv2.imread(self.checker_path)
        if color_checker is None:
            raise FileNotFoundError(f"Color checker image not found: {self.checker_path}")
        # Resize to specified dimensions
        return cv2.resize(color_checker, (width, height))

    def load_model(self):
        """Load the fine-tuned GCC model"""
        try:
            base_model = "stabilityai/stable-diffusion-2-inpainting"
            scheduler = DDIMScheduler.from_pretrained(
                base_model,
                subfolder="scheduler",
                timestep_spacing="trailing",
                prediction_type="v_prediction",
            )
            scheduler.set_timesteps(1)
            torch_dtype = torch.float32 if self.device == "cpu" else torch.float16
            unet = UNet2DConditionModel.from_pretrained(
                self.model_path,
                subfolder="unet",
                torch_dtype=torch_dtype,
            )
            self.pipeline = OneStepLaplacianInpaintPipeline.from_pretrained(
                base_model,
                torch_dtype=torch_dtype,
                scheduler=scheduler,
                unet=unet,
            )
            self.pipeline.to(self.device)
            print("✅ Model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading model: {e}")

    def gamma_correction(self, image, gamma=2.2):
        """Apply gamma encoding to map linear (RAW) values into sRGB-like space"""
        img_array = image.astype(np.float32) / 255.0
        img_corrected = np.power(img_array, 1.0 / gamma)
        return (img_corrected * 255).astype(np.uint8)
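
    # Example (illustrative): a linear value of 0.5 encodes to
    # 0.5 ** (1 / 2.2) ≈ 0.73, so linear mid-gray appears brighter after
    # gamma encoding, mimicking a simple sRGB transfer curve.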

    def inverse_gamma_correction(self, image, gamma=2.2):
        """Convert sRGB back to linear space"""
        img_array = image.astype(np.float32) / 255.0
        return np.power(img_array, gamma)
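
    # Note: the result stays as float32 in [0, 1] rather than being rescaled
    # back to 8-bit; that is the range estimate_illuminant() below works in.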

    def apply_white_balance(self, image, illuminant, gamma=2.2):
        """Apply white balance correction using estimated illuminant"""
        # Calculate gains using green channel as reference
        gains = np.array([
            illuminant[1] / illuminant[0],  # R gain
            1.0,                            # G gain (reference)
            illuminant[1] / illuminant[2],  # B gain
        ])
        # Determine original image bit depth and max value
        original_dtype = image.dtype
        if original_dtype == np.uint8:
            max_val = 255.0
        elif original_dtype == np.uint16:
            max_val = 65535.0
        elif original_dtype == np.float32 or original_dtype == np.float64:
            max_val = 1.0 if image.max() <= 1.0 else image.max()
        else:
            # For other integer types use the dtype's maximum; otherwise the image maximum
            max_val = float(np.iinfo(original_dtype).max) if np.issubdtype(original_dtype, np.integer) else float(image.max())
        # Convert BGR to RGB for processing
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if len(image.shape) == 3 else image
        # Apply white balance correction in linear space (float64 for precision)
        corrected = image_rgb.astype(np.float64) / max_val
        corrected[:, :, 0] *= gains[0]  # R channel
        corrected[:, :, 1] *= gains[1]  # G channel
        corrected[:, :, 2] *= gains[2]  # B channel
        # Clip to keep values in the valid [0, 1] range
        corrected = np.clip(corrected, 0, 1)
        # Apply gamma encoding (convert from linear to sRGB)
        gamma_corrected = np.power(corrected, 1.0 / gamma)
        # Convert back to original bit depth and BGR format
        result_rgb = (gamma_corrected * max_val).astype(original_dtype)
        result_bgr = cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR) if len(image.shape) == 3 else result_rgb
        return result_bgr
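
    # Worked example (illustrative): for an estimated illuminant of
    # [R, G, B] = [0.30, 0.50, 0.40], the per-channel gains are
    #   R: 0.50 / 0.30 ≈ 1.667,  G: 1.0,  B: 0.50 / 0.40 = 1.25,
    # a von Kries-style scaling that maps the illuminant to neutral gray.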

    def _create_swatch_masks(self, width, height, swatches_h, swatches_v, samples):
        """Helper function to create swatch masks"""
        samples_half = max(samples / 2, 1)
        masks = []
        offset_h = width / swatches_h / 2
        offset_v = height / swatches_v / 2
        for j in np.linspace(offset_v, height - offset_v, swatches_v):
            for i in np.linspace(offset_h, width - offset_h, swatches_h):
                masks.append(np.array([
                    j - samples_half,
                    j + samples_half,
                    i - samples_half,
                    i + samples_half,
                ], dtype=np.int32))
        return np.array(masks, dtype=np.int32)
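
    # Example (illustrative): for a 300x200 warp with a 6x4 grid, the patch
    # centers fall at columns linspace(25, 275, 6) and rows linspace(25, 175, 4);
    # each mask holds integer [row_start, row_end, col_start, col_end] bounds
    # of a small square sampling window around one patch center.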

    def _extract_swatch_colors(self, image, masks):
        """Helper function to extract swatch colors from masks"""
        return np.array([
            np.mean(image[mask[0]:mask[1], mask[2]:mask[3], ...], axis=(0, 1))
            for mask in masks
        ], dtype=np.float32)

    def estimate_illuminant(self, image: np.ndarray, checker_pos: tuple) -> np.ndarray:
        """Estimate illuminant color from ColorChecker Classic chart"""
        x1, y1, x2, y2 = checker_pos
        # Define the 4 corners of the detected checker
        coords_pixel = np.array([
            [x1, y1],  # top-left
            [x2, y1],  # top-right
            [x2, y2],  # bottom-right
            [x1, y2],  # bottom-left
        ], dtype=np.float32)
        # Standard working size
        working_width = int(x2 - x1)
        working_height = int(y2 - y1)
        samples = int(working_width / 15)  # adjustable sample size
        # Destination rectangle (warped top-down view)
        rectangle = np.array([
            [0, 0],
            [working_width, 0],
            [working_width, working_height],
            [0, working_height],
        ], dtype=np.float32)
        # Perspective warp
        M = cv2.getPerspectiveTransform(coords_pixel, rectangle)
        warped = cv2.warpPerspective(image, M, (working_width, working_height), flags=cv2.INTER_CUBIC)
        # Generate swatch masks and extract RGB colors
        masks = self._create_swatch_masks(working_width, working_height, 6, 4, samples)
        colours = self._extract_swatch_colors(warped, masks)
        # Use 20th patch (index 19) which corresponds to a gray patch
        estimated_illuminant = colours[19]
        return estimated_illuminant
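
    # Patch-index note: the masks are generated row-major (6 per row, 4 rows),
    # so index 19 is the second patch of the bottom row; on a ColorChecker
    # Classic in standard orientation that row is achromatic, making patch 20
    # a neutral gray whose RGB response directly reflects the scene illuminant.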

    def get_checker_position(self, image_shape, checker_shape):
        """Get centered position for color checker"""
        h, w = image_shape[:2]
        ch, cw = checker_shape[:2]
        # Center position
        x1 = (w - cw) // 2
        y1 = (h - ch) // 2
        x2 = x1 + cw
        y2 = y1 + ch
        return (x1, y1, x2, y2)
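
    # Example (illustrative): for the 512x512 model input and the default
    # 180x135 checker, this returns (x1, y1, x2, y2) = (166, 188, 346, 323),
    # centering the pasted chart in the frame.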

    def process_image(self, input_image):
        """Main processing function"""
        if input_image is None:
            return None, None, "Please upload an image"
        if self.pipeline is None:
            return None, None, "Model not loaded"
        try:
            # Read image from filepath to preserve bit depth
            if isinstance(input_image, str):
                # Read image using OpenCV to preserve bit depth
                image_bgr = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
                if image_bgr is None:
                    return None, None, "Failed to read image file"
                # Convert to RGB (if 3 channels)
                if len(image_bgr.shape) == 3:
                    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                else:
                    image_rgb = image_bgr
                # Record original size (width, height) without a lossy dtype cast
                original_size = (image_rgb.shape[1], image_rgb.shape[0])
                # Keep original BGR for processing (preserve original bit depth)
                raw_original = image_bgr.copy()
            else:
                # Fallback for PIL input (will be 8-bit)
                input_pil = input_image
                original_size = input_pil.size
                image_np = np.array(input_pil)
                image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
                raw_original = image_bgr.copy()
            # Resize to 512x512 for the model (convert to 8-bit for model processing)
            if raw_original.dtype != np.uint8:
                # Normalize to 8-bit for model processing
                if raw_original.dtype == np.uint16:
                    image_512_input = (raw_original.astype(np.float32) / 65535.0 * 255.0).astype(np.uint8)
                else:
                    # For other bit depths, normalize by the image maximum
                    max_val = np.max(raw_original)
                    image_512_input = (raw_original.astype(np.float32) / max_val * 255.0).astype(np.uint8)
            else:
                image_512_input = raw_original
            image_512 = cv2.resize(image_512_input, (512, 512))
            # Apply gamma correction for model input
            image_gamma = self.gamma_correction(image_512)
            # Get color checker position (centered in the 512x512 frame)
            checker_pos = self.get_checker_position((512, 512), self.color_checker.shape)
            # Create inpainting mask (255 marks the region to regenerate)
            # and paste the reference checker into the input
            mask = np.zeros((512, 512), dtype=np.uint8)
            mask[checker_pos[1]:checker_pos[3], checker_pos[0]:checker_pos[2]] = 255
            image_with_checker = image_gamma.copy()
            image_with_checker[checker_pos[1]:checker_pos[3], checker_pos[0]:checker_pos[2]] = self.color_checker
            # Convert to PIL for the model
            image_pil = Image.fromarray(cv2.cvtColor(image_with_checker, cv2.COLOR_BGR2RGB))
            mask_pil = Image.fromarray(mask)
            # Generate the scene-harmonized color checker
            prompt = "a scene with a color checker that accurately reflects the ambient lighting of the scene."
            output_images = self.pipeline(
                prompt,
                image=image_pil,
                mask_image=mask_pil,
                generator=torch.Generator(device="cpu").manual_seed(42),
                num_inference_steps=1,
                guidance_scale=0,
            ).images
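            # Sampling choices: num_inference_steps=1 matches the one-step
            # scheduler configured in load_model(), guidance_scale=0 disables
            # classifier-free guidance, and the fixed CPU-seeded generator
            # keeps the generated checker deterministic across runs.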
            output_image = output_images[0]
            # Convert to linear space and estimate the illuminant
            output_linear = self.inverse_gamma_correction(np.array(output_image))
            illuminant = self.estimate_illuminant(output_linear, checker_pos)
            # Apply white balance to the original image (preserving its bit depth)
            wb_image = self.apply_white_balance(raw_original, illuminant)
            wb_image_rgb = cv2.cvtColor(wb_image, cv2.COLOR_BGR2RGB)
            # Convert to PIL Image for Gradio output (will be saved as PNG)
            # Ensure proper data type conversion for PNG output
            if wb_image_rgb.dtype == np.uint16:
                # Convert 16-bit to 8-bit for PNG display while preserving visual quality
                wb_image_8bit = (wb_image_rgb.astype(np.float32) / 65535.0 * 255.0).astype(np.uint8)
                wb_image_pil = Image.fromarray(wb_image_8bit)
            elif wb_image_rgb.dtype != np.uint8:
                # Handle other bit depths
                max_val = np.max(wb_image_rgb)
                wb_image_8bit = (wb_image_rgb.astype(np.float32) / max_val * 255.0).astype(np.uint8)
                wb_image_pil = Image.fromarray(wb_image_8bit)
            else:
                wb_image_pil = Image.fromarray(wb_image_rgb)
            info = f"""🔍 Estimated Illuminant (RGB):
R: {illuminant[0]:.6f}
G: {illuminant[1]:.6f}
B: {illuminant[2]:.6f}

⚖️ White Balance Gains:
R: {illuminant[1]/illuminant[0]:.4f}
G: 1.0000
B: {illuminant[1]/illuminant[2]:.4f}"""
            return output_image, wb_image_pil, info
        except Exception as e:
            return None, None, f"Processing error: {str(e)}"

# Initialize demo
demo_app = GCCDemo(
    model_path="StevenChangWei/gcc_train_on_nus8",
    checker_path="color_chart.jpg"  # Path to your color checker image
)


@spaces.GPU  # request ZeroGPU hardware per call; a no-op on non-ZeroGPU hosts
def process_wrapper(input_image):
    return demo_app.process_image(input_image)


# Create Gradio interface
with gr.Blocks(title="GCC: Generative Color Constancy", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 📸 GCC: Generative Color Constancy via Diffusing a Color Checker

    🔗 [GitHub Repository](https://github.com/chenwei891213/GCC-official) | 🌐 [Project Page](https://chenwei891213.github.io/GCC/) | 📄 [Paper](https://arxiv.org/abs/2502.17435)

    Upload a RAW image to automatically correct its white balance using our diffusion-based color constancy method. GCC generates a color checker that accurately reflects the scene's ambient lighting, then uses it to estimate the illuminant and apply precise white balance correction.

    **Getting Started:**
    1. **Upload Your Image:** Use the image upload box on the left to provide your RAW image (supports various bit depths).
    2. **Process:** Click the "🚀 Process Image" button to start the color constancy correction.
    3. **View Results:** The generated color checker and white-balanced result will appear below.
    4. **Download:** Right-click on any result image to save it as a PNG to your device.
    5. **Technical Details (Optional):** After processing, you can view detailed processing information in the expandable section below.
    """)

    # First Row: Input Image and Generated Color Checker
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input Image")
            input_image = gr.Image(
                type="filepath",
                label="Upload RAW image",
                height=400
            )
            # Process Button
            process_btn = gr.Button(
                "🚀 Process Image",
                variant="primary",
                size="lg"
            )
        with gr.Column(scale=1):
            gr.Markdown("### 🎯 Generated Color Checker (512×512)")
            output_checker = gr.Image(
                label="Scene-harmonized color checker",
                height=400,
                format="png"
            )

    # Second Row: White Balanced Result (Full Width)
    with gr.Row():
        gr.Markdown("### ⚖️ White Balanced Result")
    with gr.Row():
        output_corrected = gr.Image(
            label="Final white balanced image (saves as PNG)",
            height=500,
            format="png"
        )

    # Processing Information (Collapsible)
    with gr.Accordion("📊 Processing Information", open=False):
        info_output = gr.Textbox(
            label="Processing Results",
            lines=8,
            interactive=False
        )
        gr.Markdown("""
        **Processing Steps:**
        1. **Preprocessing:** Resize input to 512×512, apply gamma correction (γ=2.2)
        2. **Color Checker Generation:** Generate a scene-harmonized color checker at the image center using the diffusion model
        3. **Illuminant Estimation:** Extract the illuminant from a gray patch (patch 20)
        4. **White Balance Correction:** Apply RGB gains in linear space, then apply gamma correction (γ=2.2) for sRGB output
        """)

    # Event handlers
    process_btn.click(
        fn=process_wrapper,
        inputs=[input_image],
        outputs=[output_checker, output_corrected, info_output]
    )

    gr.Markdown("""
    ### 📚 Citation
    ```bibtex
    @InProceedings{Chang_2025_CVPR,
        author    = {Chang, Chen-Wei and Fan, Cheng-De and Chang, Chia-Che and Lo, Yi-Chen and Tseng, Yu-Chee and Huang, Jiun-Long and Liu, Yu-Lun},
        title     = {GCC: Generative Color Constancy via Diffusing a Color Checker},
        booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
        month     = {June},
        year      = {2025},
        pages     = {10868-10878}
    }
    ```
    """)

if __name__ == "__main__":
    demo.launch(share=True)  # share=True is ignored when running on Hugging Face Spaces