import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation

# -------------------------------
# Load Model
# -------------------------------
MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)

# -------------------------------
# Load Labels
# -------------------------------
# One class name per line, in the same order as CITYSCAPES_PALETTE below.
with open("labels.txt", "r", encoding="utf-8") as f:
    labels_list = [line.strip() for line in f if line.strip()]

# Standard Cityscapes color palette for the 19 train classes.
CITYSCAPES_PALETTE = [
    (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
    (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
    (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
    (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100),
    (0, 80, 100), (0, 0, 230), (119, 11, 32)
]
colormap = np.array(CITYSCAPES_PALETTE, dtype=np.uint8)


def label_to_color_image(label):
    """Map an array of class ids to an RGB color image."""
    return colormap[label]


# -------------------------------
# Visualization
# -------------------------------
def draw_plot(pred_img, seg_np):
    """Render the overlay image next to a legend of the classes present."""
    fig = plt.figure(figsize=(20, 15))
    gs = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Main overlay
    plt.subplot(gs[0])
    plt.imshow(pred_img)
    plt.axis("off")

    # Legend: one color swatch per class that appears in the prediction
    LABEL_NAMES = np.asarray(labels_list)
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)

    unique_labels = np.unique(seg_np)
    ax = plt.subplot(gs[1])
    plt.imshow(FULL_COLOR_MAP[unique_labels], interpolation="nearest")
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([])
    ax.tick_params(width=0, labelsize=20)
    return fig


# -------------------------------
# Inference
# -------------------------------
def run_inference(input_img, alpha):
    """Segment the input image and blend the color mask over it."""
    img = Image.fromarray(input_img.astype(np.uint8)).convert("RGB")
    img_np = np.array(img)  # convert back to a NumPy array for blending

    # Model inference
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # shape: (1, num_classes, H/4, W/4)

    # Upsample logits to the input resolution, then take the per-pixel argmax.
    # PIL's img.size is (width, height), so reverse it to get (height, width).
    seg = torch.nn.functional.interpolate(
        logits, size=img.size[::-1], mode="bilinear", align_corners=False
    ).argmax(1)[0].cpu().numpy().astype(np.uint8)

    # Overlay controlled by the alpha slider
    color_seg = colormap[seg]
    overlay = (img_np * (1 - alpha) + color_seg * alpha).astype(np.uint8)

    fig = draw_plot(overlay, seg)
    return fig


# -------------------------------
# Gradio UI
# -------------------------------
demo = gr.Interface(
    fn=run_inference,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01,
                  label="Overlay Transparency (alpha)")
    ],
    outputs=gr.Plot(label="Segmentation Result"),
    # Each example row must supply a value for every input component.
    examples=[
        ["city1.jpg", 0.5],
        ["city2.jpg", 0.5],
        ["city3.jpg", 0.5],
        ["city4.jpg", 0.5],
        ["city5.jpg", 0.5]
    ],
    title="SegFormer B2 Cityscapes Segmentation Demo",
    cache_examples=False  # disable example caching
)

if __name__ == "__main__":
    demo.launch()
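
# -------------------------------
# Note on labels.txt (assumption)
# -------------------------------
# labels.txt is assumed to list the 19 Cityscapes train classes, one per
# line, in the palette order used above: road, sidewalk, building, wall,
# fence, pole, traffic light, traffic sign, vegetation, terrain, sky,
# person, rider, car, truck, bus, train, motorcycle, bicycle.
# If the file is unavailable, the checkpoint's own id2label mapping can
# supply the same names, e.g.:
#
#   labels_list = [model.config.id2label[i]
#                  for i in range(len(model.config.id2label))]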