Spaces:
Sleeping
Sleeping
File size: 2,987 Bytes
7fc2cce 08fca74 7fc2cce 08fca74 7fc2cce 08fca74 7fc2cce 08fca74 7fc2cce 08fca74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import sys
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from kornia.color import lab_to_rgb, rgb_to_lab
from PIL import Image, ImageOps
from safetensors.torch import load_file
REPO_ID = "ayushshah/imagecolorization"
WEIGHTS_FILE = "model.safetensors"
ARCHITECTURE_FILE = "model.py"

# Download the architecture file (model.py) into the current working directory
# so it can be imported below. ``local_dir_use_symlinks`` is intentionally not
# passed: recent huggingface_hub releases deprecate and ignore it, and always
# write regular files into ``local_dir``.
architecture_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=ARCHITECTURE_FILE,
    local_dir=".",
)

# Download the weights (into the Hugging Face cache) and keep the local path.
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=WEIGHTS_FILE,
)

# ``from model import UNet`` resolves through sys.path, whose first entry is the
# script's directory, not the working directory that received model.py. Put the
# download directory first so launching the app from another directory imports
# the freshly downloaded file instead of failing (or picking up a stale copy).
_architecture_dir = str(Path(architecture_path).resolve().parent)
if _architecture_dir not in sys.path:
    sys.path.insert(0, _architecture_dir)

# Initialize the model and load the downloaded weights; eval() switches the
# network to inference mode (e.g. dropout / batch-norm behaviour).
from model import UNet

model = UNet()
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
model.eval()
# Center crop and resize to 224x224
def prepare_input(image):
    """Center-crop an image to a square and resize it to 224x224.

    Args:
        image: Image array delivered by the Gradio ``gr.Image`` component
            (``type="numpy"``), or ``None`` when nothing was provided.

    Returns:
        A numpy array holding the largest centred square of ``image``,
        resized to 224x224 with bicubic resampling.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    source = Image.fromarray(image)
    # Edge length of the largest square that fits inside the image.
    edge = min(source.width, source.height)
    cropped = ImageOps.fit(source, (edge, edge), centering=(0.5, 0.5))
    return np.array(cropped.resize((224, 224), Image.Resampling.BICUBIC))
# Colorize the image
def colorize(image):
    """Colorize an image by predicting its LAB ``a``/``b`` channels with the UNet.

    The image is converted to LAB. Its lightness channel ``L`` (scaled to
    [0, 1]) is fed to the model. The predicted ``ab`` channels are then
    recombined with the original ``L`` and converted back to RGB.

    Args:
        image: Image array from the Gradio input component with 0-255 values,
            or ``None`` if nothing has been provided. Grayscale ``(H, W)`` and
            RGBA ``(H, W, 4)`` arrays are also accepted.

    Returns:
        A float32 ``(224, 224, 3)`` RGB array. kornia's ``lab_to_rgb`` clips
        values to [0, 1] by default.

    Raises:
        gr.Error: If ``image`` is ``None`` (raised by ``prepare_input``).
    """
    # Always hand the model a 224x224 center crop, even if the image reached
    # the component without the ``upload`` handler running.
    # NOTE(review): e.g. non-file sources; confirm which Gradio sources fire
    # ``upload``. For an already-prepared 224x224 image this is effectively a
    # no-op. It also turns a missing image into a friendly gr.Error instead of
    # a TypeError.
    image = prepare_input(image)

    if image.ndim == 2:
        # Grayscale: replicate the single channel into R, G and B.
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        # RGBA: drop alpha; the LAB conversion expects exactly 3 channels.
        image = image[..., :3]

    # (H, W, 3) in [0, 255] -> (1, 3, H, W) float32 in [0, 1] for kornia.
    rgb = torch.from_numpy(image / 255.0).permute(2, 0, 1).unsqueeze(0).float()
    lab = rgb_to_lab(rgb)
    L = lab[:, 0:1, :, :]  # keep the channel dim: (1, 1, H, W)

    with torch.no_grad():
        # kornia's L lies in [0, 100]; the model is fed L scaled to [0, 1].
        ab_pred = model(L / 100.0)

    # Map the prediction onto the ab range: -1 -> -128, 1 -> 127.5.
    # NOTE(review): assumes the model emits values in [-1, 1]; confirm
    # against the training normalization.
    ab_pred = (ab_pred + 1) * 255.0 / 2 - 128.0

    colorized_rgb = lab_to_rgb(torch.cat([L, ab_pred], dim=1))
    # squeeze(0) removes only the batch dim. A bare squeeze() would also drop
    # any spatial dim of size 1.
    return colorized_rgb.squeeze(0).permute(1, 2, 0).numpy()
def clear_images():
    """Reset the input and output image components.

    Returns:
        A 2-tuple of ``None`` values, one per component being cleared.
    """
    return (None,) * 2
# Gradio interface
# NOTE(review): the Row/Column nesting below was reconstructed from a copy that
# lost its indentation. Confirm it matches the intended layout: input column
# (image plus buttons) beside the output image.
with gr.Blocks(title="Image Colorization") as demo:
    # Page title and usage notes, rendered in definition order.
    gr.HTML("<h1 style='text-align: center;'>Image Colorization using UNet</h1>")
    gr.Markdown(
        "Upload a square image. If the image is not square, it will be center-cropped to a square image before resizing to 224x224."
    )
    gr.Markdown(
        "The input image will also be converted to the LAB color space and the L channel will be given as input to the model."
    )
    with gr.Row():
        with gr.Column():
            # Handlers receive this image as a numpy array (type="numpy").
            input_image = gr.Image(
                type="numpy",
                label="Grayscale Input",
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        output_image = gr.Image(type="numpy", label="Colorized Output",image_mode='RGB')
    # On upload, replace the displayed input with prepare_input's 224x224
    # center crop. The handler is attached to the upload event only.
    input_image.upload(prepare_input, input_image, input_image)
    # Run the colorization model on the current input and show the result.
    submit_btn.click(colorize, input_image, output_image)
    # Clear both images; clear_images returns one value per output component.
    clear_btn.click(clear_images, None, [input_image, output_image])
    gr.Markdown(
        "This Huggingface space is running entirely on CPU. For faster performance, consider running it locally with a GPU or use Google Colab/Kaggle notebooks."
    )
if __name__ == "__main__":
    # Start the Gradio server for ``demo`` when run as a script. The stray
    # trailing "|" (a page-scrape artifact and a syntax error) is removed.
    demo.launch()