# GCC-Demo / app.py
# Last commit: 58133bd "Add application file" (StevenChangWei)
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from diffusers import DDIMScheduler, UNet2DConditionModel
import spaces
from pipelines import OneStepLaplacianInpaintPipeline
class GCCDemo:
    """GCC Demo for Color Constancy using Diffusion Models"""

    def __init__(self, model_path="your-username/gcc-color-checker-diffusion", checker_path="color_checker.jpg"):
        """Load the reference color checker and build the inpainting pipeline.

        Args:
            model_path: Repository id or path passed to
                ``UNet2DConditionModel.from_pretrained``; its ``unet`` subfolder holds
                the fine-tuned UNet.
            checker_path: Image file of the reference color checker that is pasted
                into the scene before inpainting.

        Raises:
            FileNotFoundError: If ``checker_path`` cannot be read by OpenCV.
        """
        self.model_path = model_path
        self.checker_path = checker_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipeline = None
        # Filled in by load_model() on failure so process_image() can explain why.
        self.load_error = None
        self.color_checker = self.load_color_checker()
        self.load_model()

    def load_color_checker(self, width=180, height=135):
        """Read the reference checker from ``self.checker_path`` and resize it.

        Args:
            width: Target width in pixels.
            height: Target height in pixels.

        Returns:
            The resized checker as a ``(height, width, 3)`` BGR array.

        Raises:
            FileNotFoundError: If the file cannot be read. ``cv2.imread`` returns
                ``None`` instead of raising, which would otherwise surface as an
                opaque ``cv2.error`` from ``cv2.resize``.
        """
        color_checker = cv2.imread(self.checker_path)
        if color_checker is None:
            raise FileNotFoundError(f"Could not read color checker image: {self.checker_path!r}")
        return cv2.resize(color_checker, (width, height))

    @spaces.GPU
    def load_model(self):
        """Assemble the one-step inpainting pipeline (SD2-inpainting + fine-tuned UNet).

        Errors are caught on purpose so the UI can still start: ``self.pipeline`` then
        stays ``None`` and the error text is stored in ``self.load_error``.
        """
        try:
            base_model = "stabilityai/stable-diffusion-2-inpainting"
            # DDIM scheduler configured for a single trailing-spaced v-prediction step.
            scheduler = DDIMScheduler.from_pretrained(
                base_model,
                subfolder="scheduler",
                timestep_spacing="trailing",
                prediction_type="v_prediction",
            )
            scheduler.set_timesteps(1)
            # Half precision only when running on the GPU.
            torch_dtype = torch.float32 if self.device == "cpu" else torch.float16
            unet = UNet2DConditionModel.from_pretrained(
                self.model_path,
                subfolder="unet",
                torch_dtype=torch_dtype,
            )
            self.pipeline = OneStepLaplacianInpaintPipeline.from_pretrained(
                base_model,
                torch_dtype=torch_dtype,
                scheduler=scheduler,
                unet=unet,
            )
            self.pipeline.to(self.device)
            self.load_error = None
            print("✅ Model loaded successfully")
        except Exception as e:
            self.load_error = str(e)
            print(f"❌ Error loading model: {e}")

    def gamma_correction(self, image, gamma=2.2):
        """Encode an 8-bit image with a ``1/gamma`` power law (linear -> display).

        Args:
            image: Array with values in ``[0, 255]``.
            gamma: Gamma exponent.

        Returns:
            uint8 array of the same shape (values are truncated, not rounded).
        """
        img_array = image.astype(np.float32) / 255.0
        img_corrected = np.power(img_array, 1.0 / gamma)
        return (img_corrected * 255).astype(np.uint8)

    def inverse_gamma_correction(self, image, gamma=2.2):
        """Decode an 8-bit gamma-encoded image back to linear values in ``[0, 1]``.

        Returns:
            float32 array of the same shape.
        """
        img_array = image.astype(np.float32) / 255.0
        return np.power(img_array, gamma)

    @staticmethod
    def _white_balance_gains(illuminant, eps=1e-8):
        """Return ``[R, G, B]`` gains that neutralise ``illuminant``, with green as reference.

        Each component is clamped to at least ``eps`` so a zero red/blue estimate gives
        a large finite gain (later clipped) instead of ``inf``/``NaN`` pixels.
        """
        red, green, blue = (max(float(c), eps) for c in illuminant[:3])
        return np.array([green / red, 1.0, green / blue])

    @staticmethod
    def _to_bgr(image):
        """Return ``image`` as 3-channel BGR: grayscale is replicated, alpha is dropped."""
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    @staticmethod
    def _to_uint8(image):
        """Rescale an image of any dtype to uint8.

        uint8 is returned unchanged; uint16 is scaled by ``255 / 65535``; any other
        dtype is scaled so its maximum maps to 255. A non-positive or NaN maximum is
        treated as 1 to avoid dividing by zero. The result is clipped to ``[0, 255]``.
        """
        if image.dtype == np.uint8:
            return image
        if image.dtype == np.uint16:
            scaled = image.astype(np.float32) / 65535.0 * 255.0
        else:
            peak = float(np.max(image))
            if not peak > 0:  # also catches NaN
                peak = 1.0
            scaled = image.astype(np.float32) / peak * 255.0
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def apply_white_balance(self, image, illuminant, gamma=2.2):
        """White-balance a BGR image with per-channel gains, then gamma-encode it.

        Args:
            image: ``H x W x 3`` BGR image whose values are treated as linear. Supported
                dtypes are uint8, uint16, float32/float64 and other integer types.
            illuminant: Estimated illuminant in RGB order.
            gamma: Output encoding exponent; corrected values are raised to ``1/gamma``.

        Returns:
            The corrected image in BGR order, with the same dtype as ``image``.
        """
        gains = self._white_balance_gains(illuminant)

        # Normalise to [0, 1] using the full range of the input's dtype.
        original_dtype = image.dtype
        if original_dtype == np.uint8:
            max_val = 255.0
        elif original_dtype == np.uint16:
            max_val = 65535.0
        elif original_dtype == np.float32 or original_dtype == np.float64:
            # Floats already in [0, 1] keep that range; otherwise normalise by the peak.
            max_val = 1.0 if image.max() <= 1.0 else image.max()
        elif np.issubdtype(original_dtype, np.integer):
            max_val = float(np.iinfo(original_dtype).max)
        else:
            max_val = float(image.max())

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if image.ndim == 3 else image
        # float64 keeps precision for high-bit-depth inputs; gains broadcast over R, G, B.
        corrected = image_rgb.astype(np.float64) / max_val
        corrected *= gains
        corrected = np.clip(corrected, 0, 1)

        # Linear -> display encoding, then back to the input's dtype and BGR order.
        gamma_corrected = np.power(corrected, 1.0 / gamma)
        result_rgb = (gamma_corrected * max_val).astype(original_dtype)
        return cv2.cvtColor(result_rgb, cv2.COLOR_RGB2BGR) if image.ndim == 3 else result_rgb

    def _create_swatch_masks(self, width, height, swatches_h, swatches_v, samples):
        """Build sampling windows centred on each swatch of a ``swatches_h x swatches_v`` grid.

        Returns:
            int32 array with one ``[y_start, y_end, x_start, x_end]`` row per swatch,
            in row-major order (top row first, left to right).
        """
        samples_half = max(samples / 2, 1)
        masks = []
        offset_h = width / swatches_h / 2
        offset_v = height / swatches_v / 2
        for j in np.linspace(offset_v, height - offset_v, swatches_v):
            for i in np.linspace(offset_h, width - offset_h, swatches_h):
                masks.append(np.array([
                    j - samples_half,
                    j + samples_half,
                    i - samples_half,
                    i + samples_half,
                ], dtype=np.int32))
        return np.array(masks, dtype=np.int32)

    def _extract_swatch_colors(self, image, masks):
        """Return the per-channel mean of ``image`` inside each mask window (float32, one row per mask)."""
        return np.array([
            np.mean(image[mask[0]:mask[1], mask[2]:mask[3], ...], axis=(0, 1))
            for mask in masks
        ], dtype=np.float32)

    def estimate_illuminant(self, image: np.ndarray, checker_pos: tuple) -> np.ndarray:
        """Estimate the illuminant from the generated ColorChecker Classic chart.

        Args:
            image: Image containing the chart; the returned colour uses the same
                channel order (``process_image`` passes linear RGB).
            checker_pos: ``(x1, y1, x2, y2)`` pixel bounds of the chart.

        Returns:
            Mean colour of swatch index 19 of the 6 x 4 grid (row-major order).
        """
        x1, y1, x2, y2 = checker_pos
        # Corners of the chart region, in the same order as `rectangle` below.
        coords_pixel = np.array([
            [x1, y1],  # top-left
            [x2, y1],  # top-right
            [x2, y2],  # bottom-right
            [x1, y2],  # bottom-left
        ], dtype=np.float32)
        working_width = int(x2 - x1)
        working_height = int(y2 - y1)
        samples = int(working_width / 15)  # sampling window size per swatch
        rectangle = np.array([
            [0, 0],
            [working_width, 0],
            [working_width, working_height],
            [0, working_height],
        ], dtype=np.float32)
        # Warp the chart region to a top-down working rectangle.
        M = cv2.getPerspectiveTransform(coords_pixel, rectangle)
        warped = cv2.warpPerspective(image, M, (working_width, working_height), flags=cv2.INTER_CUBIC)
        masks = self._create_swatch_masks(working_width, working_height, 6, 4, samples)
        colours = self._extract_swatch_colors(warped, masks)
        # 20th patch (index 19), used as the gray reference.
        return colours[19]

    def get_checker_position(self, image_shape, checker_shape):
        """Return ``(x1, y1, x2, y2)`` that centres a checker of ``checker_shape`` in ``image_shape``."""
        h, w = image_shape[:2]
        ch, cw = checker_shape[:2]
        x1 = (w - cw) // 2
        y1 = (h - ch) // 2
        return (x1, y1, x1 + cw, y1 + ch)

    @spaces.GPU
    def process_image(self, input_image):
        """Run the full GCC pipeline on one uploaded image.

        Args:
            input_image: A file path (``str``, as produced by ``gr.Image(type="filepath")``)
                or a PIL image. Paths are read with ``cv2.IMREAD_UNCHANGED`` so 16-bit
                inputs keep their precision for the white-balance step.

        Returns:
            ``(generated_checker_image, white_balanced_image, info_text)``. On failure
            both images are ``None`` and ``info_text`` describes the problem. Exceptions
            are not propagated, so the UI always gets something to display.
        """
        if input_image is None:
            return None, None, "Please upload an image"
        if self.pipeline is None:
            load_error = getattr(self, "load_error", None)
            if load_error:
                return None, None, f"Model not loaded: {load_error}"
            return None, None, "Model not loaded"

        try:
            if isinstance(input_image, str):
                # IMREAD_UNCHANGED preserves bit depth and channel count.
                raw_original = cv2.imread(input_image, cv2.IMREAD_UNCHANGED)
                if raw_original is None:
                    return None, None, "Failed to read image file"
            else:
                # Fallback for PIL input (8-bit); convert() also handles L / P / RGBA modes.
                raw_original = cv2.cvtColor(np.array(input_image.convert("RGB")), cv2.COLOR_RGB2BGR)
            # Grayscale or BGRA inputs would otherwise break when the 3-channel checker is pasted in.
            raw_original = self._to_bgr(raw_original)

            # The model sees an 8-bit, gamma-encoded 512x512 version of the input.
            image_512 = cv2.resize(self._to_uint8(raw_original), (512, 512))
            image_gamma = self.gamma_correction(image_512)

            # Paste the reference checker at the centre and mask exactly that region for inpainting.
            checker_pos = self.get_checker_position((512, 512), self.color_checker.shape)
            x1, y1, x2, y2 = checker_pos
            mask = np.zeros((512, 512), dtype=np.uint8)
            mask[y1:y2, x1:x2] = 255
            image_with_checker = image_gamma.copy()
            image_with_checker[y1:y2, x1:x2] = self.color_checker

            image_pil = Image.fromarray(cv2.cvtColor(image_with_checker, cv2.COLOR_BGR2RGB))
            mask_pil = Image.fromarray(mask)

            prompt = "a scene with a color checker that accurately reflects the ambient lighting of the scene."
            output_image = self.pipeline(
                prompt,
                image=image_pil,
                mask_image=mask_pil,
                generator=torch.Generator(device="cpu").manual_seed(42),
                num_inference_steps=1,
                guidance_scale=0,
            ).images[0]

            # Undo the display gamma before sampling the gray patch (PIL output is RGB).
            output_linear = self.inverse_gamma_correction(np.array(output_image))
            illuminant = self.estimate_illuminant(output_linear, checker_pos)

            # Correct the full-resolution original at its native bit depth, then reduce to 8-bit for display.
            wb_image = self.apply_white_balance(raw_original, illuminant)
            wb_image_rgb = cv2.cvtColor(wb_image, cv2.COLOR_BGR2RGB)
            wb_image_pil = Image.fromarray(self._to_uint8(wb_image_rgb))

            gains = self._white_balance_gains(illuminant)
            info = "\n".join([
                "🔍 Estimated Illuminant (RGB):",
                f"R: {illuminant[0]:.6f}",
                f"G: {illuminant[1]:.6f}",
                f"B: {illuminant[2]:.6f}",
                "",
                "⚖️ White Balance Gains:",
                f"R: {gains[0]:.4f}",
                "G: 1.0000",
                f"B: {gains[2]:.4f}",
            ])
            return output_image, wb_image_pil, info
        except Exception as e:
            # UI boundary: report the failure to the user instead of failing the request.
            return None, None, f"Processing error: {e}"
# Single shared GCCDemo instance used by the Gradio callback below; constructing it
# loads the reference checker image and the diffusion model once, at import time.
demo_app = GCCDemo(checker_path="color_chart.jpg", model_path="StevenChangWei/gcc_train_on_nus8")
def process_wrapper(input_image):
    """Gradio click handler: forward the upload to the shared ``demo_app`` instance.

    Returns the ``(checker_image, white_balanced_image, info_text)`` triple from
    ``GCCDemo.process_image`` unchanged.
    """
    outputs = demo_app.process_image(input_image)
    return outputs
# Create Gradio interface.
# Markdown bodies are kept flush-left so Markdown never treats them as indented code.
with gr.Blocks(title="GCC: Generative Color Constancy", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 📸 GCC: Generative Color Constancy via Diffusing a Color Checker

🌟 [GitHub Repository](https://github.com/chenwei891213/GCC-official) | 🚀 [Project Page](https://chenwei891213.github.io/GCC/) | 📄 [Paper](https://arxiv.org/abs/2502.17435)

Upload a RAW image to automatically correct its white balance using our diffusion-based color constancy method. GCC generates a color checker that accurately reflects the scene's ambient lighting, then uses it to estimate the illuminant and apply precise white balance correction.

**Getting Started:**

1. **Upload Your Image:** Use the image upload box on the left to provide your RAW image (supports various bit depths).
2. **Process:** Click the "🚀 Process Image" button to start the color constancy correction.
3. **View Results:** The generated color checker and white balanced result will appear below.
4. **Download:** Right-click on any result image to save it as PNG to your device.
5. **Technical Details (Optional):** After processing, you can view detailed processing information in the expandable section below.
""")

    # First row: input image (with the process button) next to the generated color checker.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📤 Input Image")
            # "filepath" lets process_image read the file itself and keep its bit depth.
            input_image = gr.Image(
                type="filepath",
                label="Upload RAW image",
                height=400,
            )
            process_btn = gr.Button(
                "🚀 Process Image",
                variant="primary",
                size="lg",
            )
        with gr.Column(scale=1):
            gr.Markdown("### 🎯 Generated Color Checker (512×512)")
            output_checker = gr.Image(
                label="Scene-harmonized color checker",
                height=400,
                format="png",
            )

    # Second row: white balanced result at full width.
    with gr.Row():
        gr.Markdown("### ⚖️ White Balanced Result")
    with gr.Row():
        output_corrected = gr.Image(
            label="Final white balanced image (saves as PNG)",
            height=500,
            format="png",
        )

    # Processing information (collapsed by default).
    with gr.Accordion("📋 Processing Information", open=False):
        info_output = gr.Textbox(
            label="Processing Results",
            lines=8,
            interactive=False,
        )
        gr.Markdown("""
**Processing Steps:**

1. **Preprocessing:** Resize input to 512×512, apply gamma correction (γ=2.2)
2. **Color Checker Generation:** Generate scene-harmonized color checker at image center using diffusion model
3. **Illuminant Estimation:** Extract illuminant from gray patch (patch 20)
4. **White Balance Correction:** Apply RGB gains in linear space, then apply gamma correction (γ=2.2) for sRGB output
""")

    # Event handlers: process_wrapper returns (checker image, corrected image, info text).
    process_btn.click(
        fn=process_wrapper,
        inputs=[input_image],
        outputs=[output_checker, output_corrected, info_output],
    )

    gr.Markdown("""
### 📄 Citation

```bibtex
@InProceedings{Chang_2025_CVPR,
    author    = {Chang, Chen-Wei and Fan, Cheng-De and Chang, Chia-Che and Lo, Yi-Chen and Tseng, Yu-Chee and Huang, Jiun-Long and Liu, Yu-Lun},
    title     = {GCC: Generative Color Constancy via Diffusing a Color Checker},
    booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
    month     = {June},
    year      = {2025},
    pages     = {10868-10878}
}
```
""")

if __name__ == "__main__":
    demo.launch(share=True)