|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib import gridspec |
|
|
from PIL import Image |
|
|
import torch |
|
|
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Identifier of the pretrained checkpoint passed to both `from_pretrained`
# calls below (per its name: SegFormer-B2 fine-tuned on Cityscapes).
MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"

# Both objects are created once at import time and reused by every
# inference request instead of being rebuilt per call.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Class names, one per line of labels.txt; the position of a name in
# `labels_list` is the class id it is looked up by when the legend is drawn.
with open("labels.txt", "r", encoding="utf-8") as f:
    # Blank lines (e.g. extra newlines at the end of the file) are skipped so
    # they neither become empty legend entries nor inflate the label count.
    labels_list = [name for name in (line.strip() for line in f) if name]
|
|
|
|
|
# RGB colour for each class id: entry i is the colour painted for class i.
CITYSCAPES_PALETTE = [
    (128, 64, 128),
    (244, 35, 232),
    (70, 70, 70),
    (102, 102, 156),
    (190, 153, 153),
    (153, 153, 153),
    (250, 170, 30),
    (220, 220, 0),
    (107, 142, 35),
    (152, 251, 152),
    (70, 130, 180),
    (220, 20, 60),
    (255, 0, 0),
    (0, 0, 142),
    (0, 0, 70),
    (0, 60, 100),
    (0, 80, 100),
    (0, 0, 230),
    (119, 11, 32),
]

# Same table as a (num_classes, 3) uint8 array so that an array of class ids
# can be turned into colours with a single indexing operation.
colormap = np.asarray(CITYSCAPES_PALETTE, dtype=np.uint8)
|
|
|
|
|
def label_to_color_image(label):
    """Look up the palette colour for each class id in *label*.

    Args:
        label: Index (typically an integer array of class ids) applied
            directly to the module-level ``colormap`` table.

    Returns:
        The result of ``colormap[label]``; for an integer array this is an
        array of the same shape with a trailing RGB axis of length 3.

    Raises:
        IndexError: If an id falls outside the palette.
    """
    return colormap[label]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_plot(pred_img, seg_np):
    """Build a figure showing the overlay image next to a class legend.

    The figure is created as a standalone ``Figure`` rather than through
    ``pyplot``: pyplot keeps a global reference to every figure it opens, so
    figures created per request would accumulate unless explicitly closed,
    and its implicit "current figure/axes" state is shared between
    concurrent requests.

    Args:
        pred_img: Image array to display in the left panel.
        seg_np: Array of integer class ids; only the ids present in it are
            listed in the legend.

    Returns:
        The matplotlib ``Figure`` with the image panel (left) and the colour
        legend (right).
    """
    # Local import keeps this function self-contained with respect to the
    # file's import section.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(20, 15))
    gs = gridspec.GridSpec(1, 2, width_ratios=[6, 1])

    # Left panel: the blended prediction image, without axes.
    image_ax = fig.add_subplot(gs[0])
    image_ax.imshow(pred_img)
    image_ax.axis("off")

    # Right panel: one colour swatch per class that actually occurs.
    unique_labels = np.unique(seg_np)

    # Index the palette only with the ids that are present, so extra entries
    # in the label file cannot push the lookup past the end of the palette.
    # Shape (k, 1, 3): a one-pixel-wide column of k swatches.
    legend_colors = colormap[unique_labels][:, np.newaxis, :]

    # Fall back to the numeric id when the label file has no name for it.
    legend_names = [
        labels_list[idx] if idx < len(labels_list) else str(idx)
        for idx in unique_labels.tolist()
    ]

    legend_ax = fig.add_subplot(gs[1])
    legend_ax.imshow(legend_colors, interpolation="nearest")
    legend_ax.yaxis.tick_right()
    legend_ax.set_yticks(list(range(len(legend_names))))
    legend_ax.set_yticklabels(legend_names)
    legend_ax.set_xticks([])
    legend_ax.tick_params(width=0, labelsize=20)

    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_inference(input_img, alpha):
    """Segment an image and return a figure of the colour overlay.

    Args:
        input_img: Input image as a NumPy array (as delivered by the
            ``gr.Image(type="numpy")`` input), or ``None`` when the user
            submitted without providing an image.
        alpha: Blend weight of the segmentation colours over the input
            image; values outside ``[0, 1]`` are clamped to that range.

    Returns:
        The matplotlib figure produced by ``draw_plot`` for the blended
        image and the predicted class map.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an input image.")

    # Keep the blend inside the uint8 range; an out-of-range weight would
    # otherwise wrap around when the result is cast below.
    alpha = float(np.clip(alpha, 0.0, 1.0))

    # Normalise to 3-channel RGB regardless of the channels in the input.
    img = Image.fromarray(input_img.astype(np.uint8)).convert("RGB")
    img_np = np.array(img)

    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    # Resize the logits to the input resolution before taking the per-pixel
    # argmax. PIL reports size as (width, height); interpolate expects
    # (height, width), hence the reversal.
    upsampled = torch.nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    seg = upsampled.argmax(1)[0].cpu().numpy().astype(np.uint8)

    # Paint each pixel with its class colour and blend with the input image.
    color_seg = colormap[seg]
    overlay = (img_np * (1 - alpha) + color_seg * alpha).astype(np.uint8)

    return draw_plot(overlay, seg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Web UI definition: an image plus a blend slider go in, and the figure
# returned by `run_inference` comes out.
demo = gr.Interface(
    title="SegFormer B2 Cityscapes Segmentation Demo",
    fn=run_inference,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.5,
            step=0.01,
            label="Overlay Transparency (alpha)",
        ),
    ],
    outputs=gr.Plot(label="Segmentation Result"),
    # Each example row supplies only the image input.
    examples=[
        ["city1.jpg"],
        ["city2.jpg"],
        ["city3.jpg"],
        ["city4.jpg"],
        ["city5.jpg"],
    ],
    cache_examples=False,
)
|
|
|
|
|
# Start the web server only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()
|
|
|