File size: 2,987 Bytes
7fc2cce
 
 
08fca74
7fc2cce
 
 
08fca74
7fc2cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08fca74
7fc2cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08fca74
7fc2cce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08fca74
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import gradio as gr
from PIL import Image, ImageOps
import numpy as np
from kornia.color import rgb_to_lab, lab_to_rgb


REPO_ID = "ayushshah/imagecolorization"
WEIGHTS_FILE = "model.safetensors"
ARCHITECTURE_FILE = "model.py"


# Download the architecture file (model.py) into the current working directory
# so it can be imported below.
# NOTE(review): the later `from model import UNet` assumes the working
# directory is importable (e.g. the app is launched from its own directory, as
# on a Hugging Face Space) — confirm for other launch setups.
# `local_dir_use_symlinks` is deprecated and ignored by recent huggingface_hub
# releases (files under `local_dir` are real copies), so it is not passed.
hf_hub_download(
    repo_id=REPO_ID,
    filename=ARCHITECTURE_FILE,
    local_dir=".",
)

# Download the weights; hf_hub_download returns the local (cached) file path.
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=WEIGHTS_FILE,
)


# Initialize the model. This import must come after the download above:
# model.py does not exist locally until hf_hub_download has fetched it.
from model import UNet  # noqa: E402

model = UNet()
# safetensors' load_file places tensors on the CPU by default.
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
# Switch to evaluation mode for inference (affects layers such as dropout or
# batch norm, if UNet contains any).
model.eval()


# Center crop and resize to 224x224
def prepare_input(image):
    """Center-crop an uploaded image to a square, then resize it to 224x224.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``, or
            ``None`` when nothing has been uploaded.

    Returns:
        A numpy array of the 224x224 center-cropped, bicubic-resized image.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    source = Image.fromarray(image)
    # The shorter edge sets the size of the largest centered square crop.
    crop_side = min(source.size)
    centered_square = ImageOps.fit(
        source,
        (crop_side, crop_side),
        centering=(0.5, 0.5),
    )
    thumbnail = centered_square.resize((224, 224), Image.Resampling.BICUBIC)
    return np.array(thumbnail)


# Colorize the image
def colorize(image):
    """Colorize an image by predicting its LAB chroma channels with the UNet.

    The image is converted to CIELAB with kornia. Only the L (lightness)
    channel, scaled by 1/100, is fed to the model. The model's a/b prediction
    is rescaled, recombined with the original L, and converted back to RGB.

    Args:
        image: Numpy image array with 0-255 pixel values, as delivered by
            ``gr.Image(type="numpy")``. Accepts H x W x 3 (RGB). H x W and
            H x W x 1 grayscale inputs are replicated to three channels, and
            H x W x 4 (RGBA) inputs have their alpha channel dropped.

    Returns:
        H x W x 3 float32 numpy array of RGB values (kornia's ``lab_to_rgb``
        clips to [0, 1] by default).

    Raises:
        gr.Error: If no image has been provided.
    """
    if image is None:
        # Without this guard, `None / 255.0` raises an opaque TypeError.
        raise gr.Error("Please upload an image.")

    pixels = np.asarray(image)
    # rgb_to_lab requires exactly three channels, so normalize the layout.
    if pixels.ndim == 2:
        pixels = pixels[..., np.newaxis]
    if pixels.shape[-1] == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    elif pixels.shape[-1] == 4:
        pixels = pixels[..., :3]

    # HWC in [0, 255] -> NCHW float in [0, 1], the layout/range kornia expects.
    img_tensor = torch.from_numpy(pixels / 255.0).permute(2, 0, 1).unsqueeze(0).float()

    lab_tensor = rgb_to_lab(img_tensor)

    # Keep the channel dimension (0:1) so L stays 4-D for the model and torch.cat.
    L = lab_tensor[:, 0:1, :, :]
    L_normalized = L / 100.0

    with torch.no_grad():
        ab_pred = model(L_normalized)

    # Affine map from the model's output scale to LAB a/b units:
    # -1 -> -128 and 1 -> 127 (assumes outputs lie in [-1, 1] — matches training).
    ab_pred = (ab_pred + 1) * 255.0 / 2 - 128.0
    combined_lab = torch.cat([L, ab_pred], dim=1)
    colorized_rgb = lab_to_rgb(combined_lab)
    # Drop only the batch dimension; a bare squeeze() would also collapse a
    # height or width of 1.
    return colorized_rgb[0].permute(1, 2, 0).numpy()


def clear_images():
    """Return empty values for the input and output image components.

    Returns:
        A ``(None, None)`` pair; Gradio writes ``None`` into each output,
        which clears the component.
    """
    cleared = None
    return cleared, cleared


# Gradio interface
with gr.Blocks(title="Image Colorization") as demo:
    # Page title and usage notes shown above the widgets.
    gr.HTML("<h1 style='text-align: center;'>Image Colorization using UNet</h1>")
    gr.Markdown(
        "Upload a square image. If the image is not square, it will be center-cropped to a square image before resizing to 224x224."
    )
    gr.Markdown(
        "The input image will also be converted to the LAB color space and the L channel will be given as input to the model."
    )

    # One row: a column holding the input image and its buttons, followed by
    # the output image.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Grayscale Input", type="numpy")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        output_image = gr.Image(
            label="Colorized Output",
            type="numpy",
            image_mode="RGB",
        )

    # Event wiring:
    # - on upload, replace the input image with its 224x224 center crop;
    # - Submit runs the model on the input;
    # - Clear resets both image components.
    input_image.upload(fn=prepare_input, inputs=input_image, outputs=input_image)
    submit_btn.click(fn=colorize, inputs=input_image, outputs=output_image)
    clear_btn.click(fn=clear_images, inputs=None, outputs=[input_image, output_image])

    gr.Markdown(
        "This Huggingface space is running entirely on CPU. For faster performance, consider running it locally with a GPU or use Google Colab/Kaggle notebooks."
    )


if __name__ == "__main__":
    demo.launch()