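"""Gradio demo: semantic street-scene segmentation with SegFormer-B2
fine-tuned on Cityscapes (nvidia/segformer-b2-finetuned-cityscapes-1024-1024)."""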
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

# -------------------------------
# Load Model
# -------------------------------
MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
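# from_pretrained returns the model in eval mode on CPU, which is all this
# demo needs. If a GPU is available, an optional speed-up (not wired in here)
# would be:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device)  # and later: inputs = inputs.to(device)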

# -------------------------------
# Load Labels
# -------------------------------
with open("labels.txt", "r", encoding="utf-8") as f:
    labels_list = [line.strip() for line in f]
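# labels.txt is expected to contain the 19 Cityscapes class names, one per
# line, in the same order as the model's output channels (model.config.id2label).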

# Official Cityscapes color palette (train IDs), index-aligned with labels_list.
CITYSCAPES_PALETTE = [
    (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
    (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
    (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
    (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
    (0, 0, 230), (119, 11, 32)
]
colormap = np.array(CITYSCAPES_PALETTE, dtype=np.uint8)

def label_to_color_image(label):
    """Map an array of class indices to their RGB palette colors."""
    return colormap[label]

# -------------------------------
# Visualization
# -------------------------------
def draw_plot(pred_img, seg_np):
    fig = plt.figure(figsize=(20, 15))
    gs = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # main overlay
    plt.subplot(gs[0])
    plt.imshow(pred_img)
    plt.axis("off")

    # legend: one color swatch per class present in the prediction
    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)  # (N, 1, 3) swatches

    unique_labels = np.unique(seg_np)
    ax = plt.subplot(gs[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels], interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([])
    ax.tick_params(width=0, labelsize=20)

    return fig

# -------------------------------
# Inference
# -------------------------------
def run_inference(input_img, alpha):
    if input_img is None:  # nothing uploaded yet
        return None
    img = Image.fromarray(input_img.astype(np.uint8)).convert("RGB")
    img_np = np.array(img)  # keep a numpy copy for alpha blending

    # Model inference
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # SegFormer predicts logits at 1/4 of the input resolution; upsample to
    # the original size (PIL's img.size is (W, H), hence the reversal) and
    # take the per-pixel argmax to get a class-index map.
    seg = torch.nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False
    ).argmax(1)[0].cpu().numpy().astype(np.uint8)

    # Alpha-blend the color mask over the photo: alpha=0 shows only the image,
    # alpha=1 only the segmentation colors.
    color_seg = colormap[seg]  # (H, W, 3) uint8 color mask
    overlay = (img_np * (1 - alpha) + color_seg * alpha).astype(np.uint8)

    fig = draw_plot(overlay, seg)
    return fig
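
# run_inference returns a Matplotlib Figure, which gr.Plot renders directly.
# A quick smoke test without the UI (assumes a sample image exists locally):
#   fig = run_inference(np.array(Image.open("city1.jpg")), alpha=0.5)
#   fig.savefig("seg_result.png")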

# -------------------------------
# Gradio UI
# -------------------------------
demo = gr.Interface(
    fn=run_inference,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Overlay Transparency (alpha)")
    ],
    outputs=gr.Plot(label="Segmentation Result"),
    # Each example must supply one value per input component (image, alpha).
    examples=[
        ["city1.jpg", 0.5],
        ["city2.jpg", 0.5],
        ["city3.jpg", 0.5],
        ["city4.jpg", 0.5],
        ["city5.jpg", 0.5]
    ],
    title="SegFormer B2 Cityscapes Segmentation Demo",
    cache_examples=False  # disable example caching
)

if __name__ == "__main__":
    demo.launch()
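
# Run with `python app.py` (assuming this file is saved as app.py); Gradio
# serves the demo at http://127.0.0.1:7860 by default.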