import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
# -------------------------------
# Load Model
# -------------------------------
MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
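
# Note (added): from_pretrained() already returns the model in eval mode, so eval() here is a
# defensive no-op. This demo assumes CPU inference; moving the model to a GPU would also
# require moving the processor outputs onto the same device inside run_inference.
model.eval()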
# -------------------------------
# Load Labels
# -------------------------------
with open("labels.txt", "r", encoding="utf-8") as f:
    labels_list = [line.strip() for line in f if line.strip()]  # skip blank lines

CITYSCAPES_PALETTE = [
    (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
    (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
    (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
    (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
    (0, 0, 230), (119, 11, 32)
]
colormap = np.array(CITYSCAPES_PALETTE, dtype=np.uint8)
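
# Sanity check (added, not in the original app): the Cityscapes palette above has 19 entries,
# so labels.txt is expected to list exactly 19 class names in the same order.
assert len(labels_list) == len(CITYSCAPES_PALETTE), (
    f"labels.txt has {len(labels_list)} labels, expected {len(CITYSCAPES_PALETTE)}"
)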
def label_to_color_image(label):
    """Map an array of integer class IDs to RGB colors via the palette."""
    return colormap[label]
# -------------------------------
# Visualization
# -------------------------------
def draw_plot(pred_img, seg_np):
    """Render the blended overlay next to a legend of the classes present in the mask."""
    fig = plt.figure(figsize=(20, 15))
    gs = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Main overlay
    plt.subplot(gs[0])
    plt.imshow(pred_img)
    plt.axis("off")

    # Legend: only the classes that actually appear in the prediction
    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
    unique_labels = np.unique(seg_np)

    ax = plt.subplot(gs[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels], interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([])
    ax.tick_params(width=0, labelsize=20)
    return fig
# -------------------------------
# Inference
# -------------------------------
def run_inference(input_img, alpha):
    """Segment the image and blend the class colors over it with the given alpha."""
    if input_img is None:
        raise gr.Error("Please provide an input image.")
    img = Image.fromarray(input_img.astype(np.uint8)).convert("RGB")
    img_np = np.array(img)  # keep a numpy copy for blending

    # Model inference
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Upsample logits to the input resolution and take the per-pixel argmax
    seg = torch.nn.functional.interpolate(
        logits,
        size=img.size[::-1],  # PIL size is (width, height); interpolate expects (height, width)
        mode="bilinear",
        align_corners=False
    ).argmax(1)[0].cpu().numpy().astype(np.uint8)

    # Overlay controlled by the alpha slider
    color_seg = colormap[seg]
    overlay = (img_np * (1 - alpha) + color_seg * alpha).astype(np.uint8)
    fig = draw_plot(overlay, seg)
    return fig
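
# Standalone usage sketch (assumption: a local "city1.jpg" sits next to this script):
#   fig = run_inference(np.array(Image.open("city1.jpg")), alpha=0.5)
#   fig.savefig("segmentation.png")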
# -------------------------------
# Gradio UI
# -------------------------------
demo = gr.Interface(
    fn=run_inference,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Overlay Transparency (alpha)")
    ],
    outputs=gr.Plot(label="Segmentation Result"),
    # Each example supplies both inputs: the image path and a default alpha for the slider
    examples=[
        ["city1.jpg", 0.5],
        ["city2.jpg", 0.5],
        ["city3.jpg", 0.5],
        ["city4.jpg", 0.5],
        ["city5.jpg", 0.5]
    ],
    title="SegFormer B2 Cityscapes Segmentation Demo",
    cache_examples=False  # do not cache example outputs
)
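
# With cache_examples=False the example images are segmented when clicked rather than
# precomputed at startup, which keeps launch fast at the cost of a short wait per example.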
if __name__ == "__main__":
    demo.launch()