import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
# -------------------------------
# Load Model
# -------------------------------
MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
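# SegFormer-B2 fine-tuned on Cityscapes (19 classes, 1024x1024 crops);
# from_pretrained() already puts the model in eval mode.
# Optional sketch, assuming a CUDA device is available:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)
# (the tensors returned by processor(...) would then also need .to(device))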
# -------------------------------
# Load Labels
# -------------------------------
with open("labels.txt", "r", encoding="utf-8") as f:
    labels_list = [line.strip() for line in f]
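# labels.txt is assumed to list the 19 Cityscapes class names in trainId
# order (road, sidewalk, building, ...), matching the palette below.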
CITYSCAPES_PALETTE = [
    (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
    (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
    (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
    (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
    (0, 0, 230), (119, 11, 32)
]
colormap = np.array(CITYSCAPES_PALETTE, dtype=np.uint8)
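# Indexing the (19, 3) colormap with an integer label array of any shape
# returns an RGB array with a trailing axis of 3 (NumPy fancy indexing).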
def label_to_color_image(label):
    return colormap[label]
# -------------------------------
# Visualization
# -------------------------------
def draw_plot(pred_img, seg_np):
    fig = plt.figure(figsize=(20, 15))
    gs = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
    # main overlay
    plt.subplot(gs[0])
    plt.imshow(pred_img)
    plt.axis("off")
    # legend
    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
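    # Show legend rows only for classes present in the prediction;
    # np.unique() returns sorted ids, so the rows align with the color strip.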
    unique_labels = np.unique(seg_np)
    ax = plt.subplot(gs[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels], interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([])
    ax.tick_params(width=0, labelsize=20)
    return fig
# -------------------------------
# Inference
# -------------------------------
def run_inference(input_img, alpha):
    img = Image.fromarray(input_img.astype(np.uint8)).convert("RGB")
    img_np = np.array(img)  # convert back to a NumPy array for blending
    # Model inference
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
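    # SegFormer predicts logits at 1/4 of the input resolution, so upsample
    # back to the original image size before taking the per-pixel argmax.
    # PIL's img.size is (W, H) while interpolate expects (H, W), hence [::-1].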
    seg = torch.nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False
    ).argmax(1)[0].cpu().numpy().astype(np.uint8)
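    # Linear blend: alpha = 0 shows only the photo, alpha = 1 only the
    # class colors.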
    # Overlay with alpha slider
    color_seg = colormap[seg]
    overlay = (img_np * (1 - alpha) + color_seg * alpha).astype(np.uint8)
    fig = draw_plot(overlay, seg)
    return fig
# -------------------------------
# Gradio UI
# -------------------------------
demo = gr.Interface(
    fn=run_inference,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Overlay Transparency (alpha)")
    ],
    outputs=gr.Plot(label="Segmentation Result"),
    # Each example must supply a value for every input component, so the
    # default alpha of 0.5 is included alongside each image path.
    examples=[
        ["city1.jpg", 0.5],
        ["city2.jpg", 0.5],
        ["city3.jpg", 0.5],
        ["city4.jpg", 0.5],
        ["city5.jpg", 0.5]
    ],
    title="SegFormer B2 Cityscapes Segmentation Demo",
    cache_examples=False  # disable example caching
)
if __name__ == "__main__":
    demo.launch()