Spaces:
Sleeping
Sleeping
| import torch | |
| from safetensors.torch import load_file | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
| from PIL import Image, ImageOps | |
| import numpy as np | |
| from kornia.color import rgb_to_lab, lab_to_rgb | |
REPO_ID = "ayushshah/imagecolorization"
WEIGHTS_FILE = "model.safetensors"
ARCHITECTURE_FILE = "model.py"

# Download the architecture file (it provides ``UNet``) into the current
# working directory so it can be imported as the ``model`` module below.
# ``local_dir_use_symlinks`` is intentionally not passed: recent
# huggingface_hub releases deprecate and ignore it.
architecture_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=ARCHITECTURE_FILE,
    local_dir=".",
)

# ``from model import UNet`` searches sys.path, whose first entry is the
# script's own directory -- not necessarily the working directory the file
# was just written to. Add the download directory explicitly so the import
# works no matter where the app is launched from.
import os
import sys

_architecture_dir = os.path.dirname(os.path.abspath(architecture_path))
if _architecture_dir not in sys.path:
    sys.path.insert(0, _architecture_dir)

# Download the weights (no local_dir, so huggingface_hub's default cache is used).
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=WEIGHTS_FILE,
)

# Build the network, load the pretrained safetensors weights and switch the
# model to inference mode.
from model import UNet

model = UNet()
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
model.eval()
# Center crop and resize to 224x224
def prepare_input(image):
    """Crop an uploaded image to its centered square and scale it to 224x224.

    Args:
        image: numpy array delivered by the Gradio image component, or None
            when nothing has been uploaded.

    Returns:
        numpy array of the square image after a BICUBIC resize to 224x224.

    Raises:
        gr.Error: when ``image`` is None.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    source = Image.fromarray(image)
    width, height = source.size
    # The largest square that fits is bounded by the shorter side.
    edge = width if width <= height else height
    # Square crop centered on the image, at the square's native size.
    centered = ImageOps.fit(
        source,
        (edge, edge),
        centering=(0.5, 0.5),
    )
    target_size = (224, 224)
    scaled = centered.resize(target_size, Image.Resampling.BICUBIC)
    return np.array(scaled)
# Colorize the image
def colorize(image):
    """Predict colors for ``image`` from its LAB lightness channel.

    The RGB input is converted to LAB with kornia. The L channel (divided by
    100) is passed through the UNet. The predicted a/b channels are then
    recombined with the original L and converted back to RGB.

    Args:
        image: numpy image from the Gradio input component, normally the
            224x224 RGB array written back by ``prepare_input``. Assumed to
            hold values in the uint8 range [0, 255] -- TODO confirm for every
            input source. A 2-D grayscale array or a 4-channel RGBA array is
            also accepted.

    Returns:
        HxWx3 float numpy array in RGB channel order, as produced by kornia's
        ``lab_to_rgb``.

    Raises:
        gr.Error: if Submit is pressed before any image has been uploaded.
    """
    if image is None:
        # Same message as prepare_input, instead of a TypeError from
        # ``None / 255.0``.
        raise gr.Error("Please upload an image.")

    image = np.asarray(image)
    if image.ndim == 2:
        # Single-channel input: replicate it so rgb_to_lab receives 3 channels.
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        # Drop the alpha channel; only RGB takes part in the LAB conversion.
        image = image[..., :3]

    # HWC array scaled to [0, 1] -> (1, 3, H, W) float32 tensor.
    img_tensor = torch.from_numpy(image / 255.0).permute(2, 0, 1).unsqueeze(0).float()
    lab_tensor = rgb_to_lab(img_tensor)
    L = lab_tensor[:, 0:1, :, :]
    # L is nominally in [0, 100]; the model is fed it rescaled by 1/100.
    L_normalized = L / 100.0

    with torch.no_grad():
        ab_pred = model(L_normalized)

    # Map the model output from [-1, 1] to [-128, 127] for the a/b channels
    # (assumes a tanh-style output range -- confirm against model.py).
    ab_pred = (ab_pred + 1) * 255.0 / 2 - 128.0

    combined_lab = torch.cat([L, ab_pred], dim=1)
    colorized_rgb = lab_to_rgb(combined_lab)
    # Index the batch dimension explicitly: squeeze() would also remove an
    # H or W of size 1 and break the CHW -> HWC permute.
    return colorized_rgb[0].permute(1, 2, 0).numpy()
def clear_images():
    """Return empty values for both image components.

    Used by the Clear button, whose outputs are ``[input_image, output_image]``,
    so the returned pair resets the input and the colorized output.
    """
    return (None,) * 2
# Gradio interface
with gr.Blocks(title="Image Colorization") as demo:
    # Page header and usage notes, rendered top to bottom.
    gr.HTML("<h1 style='text-align: center;'>Image Colorization using UNet</h1>")
    gr.Markdown(
        "Upload a square image. If the image is not square, it will be center-cropped to a square image before resizing to 224x224."
    )
    gr.Markdown(
        "The input image will also be converted to the LAB color space and the L channel will be given as input to the model."
    )

    with gr.Row():
        # Left column: the input image with its action buttons underneath.
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Grayscale Input")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        # Right side of the row: the model's colorized result.
        output_image = gr.Image(
            type="numpy",
            label="Colorized Output",
            image_mode="RGB",
        )

    # Event wiring. Registration order is kept as-is (it fixes the order of
    # the app's event handlers).
    # A fresh upload is cropped/resized in place, replacing the input value.
    input_image.upload(fn=prepare_input, inputs=input_image, outputs=input_image)
    submit_btn.click(fn=colorize, inputs=input_image, outputs=output_image)
    clear_btn.click(fn=clear_images, inputs=None, outputs=[input_image, output_image])

    gr.Markdown(
        "This Huggingface space is running entirely on CPU. For faster performance, consider running it locally with a GPU or use Google Colab/Kaggle notebooks."
    )

if __name__ == "__main__":
    demo.launch()