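"""Gradio demo: semantic segmentation of city scenes with SegFormer-B2 (Cityscapes)."""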
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import gridspec
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
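
# Load the image processor and the SegFormer-B2 model fine-tuned on Cityscapes.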
MODEL_ID = "nvidia/segformer-b2-finetuned-cityscapes-1024-1024"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForSemanticSegmentation.from_pretrained(MODEL_ID)
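
# RGB palette for the 19 Cityscapes classes, one triple per class index
# (the helper keeps its ade_palette name, but these are Cityscapes colors).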
def ade_palette():
return [
[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153],
[153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152],
[70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
[0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]
]
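
# Class names, one per line, in the same order as the palette indices.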
labels_list = []
with open("labels.txt", "r", encoding="utf-8") as fp:
for line in fp:
labels_list.append(line.rstrip("\n"))
colormap = np.asarray(ade_palette(), dtype=np.uint8)
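
# Map a 2-D array of class indices to an RGB image via the palette.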
def label_to_color_image(label):
    if label.ndim != 2:
        raise ValueError("Expected a 2-D label array")
    if np.max(label) >= len(colormap):
        raise ValueError("Label value exceeds the palette size")
return colormap[label]
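
# Render the blended prediction next to a legend of only the classes present.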
def draw_plot(pred_img, seg_np):
fig = plt.figure(figsize=(20, 15))
grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
plt.subplot(grid_spec[0])
plt.imshow(pred_img)
plt.axis('off')
LABEL_NAMES = np.asarray(labels_list)
FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
unique_labels = np.unique(seg_np.astype("uint8"))
ax = plt.subplot(grid_spec[1])
plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
ax.yaxis.tick_right()
plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
plt.xticks([], [])
ax.tick_params(width=0.0, labelsize=25)
return fig
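
# End-to-end inference: preprocess, forward pass, upsample, colorize, blend.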
def run_inference(input_img, alpha=0.5):
start_time = time.time()
img = Image.fromarray(input_img.astype(np.uint8)) if isinstance(input_img, np.ndarray) else input_img
if img.mode != "RGB":
img = img.convert("RGB")
inputs = processor(images=img, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
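    # Upsample logits to the input resolution; PIL's .size is (W, H), so reverse it.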
upsampled = torch.nn.functional.interpolate(
logits, size=img.size[::-1], mode="bilinear", align_corners=False
)
seg = upsampled.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
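    # Map each pixel's class index to its RGB palette color.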
color_seg = colormap[seg]
    # Blend the original image with the color mask; alpha controls the mask opacity
image_weight = 1.0 - alpha
overlay_weight = alpha
pred_img = (np.array(img) * image_weight + color_seg * overlay_weight).astype(np.uint8)
    fig = draw_plot(pred_img, seg)
    # Report end-to-end latency measured from the start of the function.
    print(f"Inference took {time.time() - start_time:.2f}s")
    return fig
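
# Custom Soft theme: emerald/teal accents on a light slate background.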
custom_theme = gr.themes.Soft(
primary_hue="emerald", secondary_hue="teal", neutral_hue="slate"
).set(
body_background_fill="#f9fafb",
body_text_color="#1f2937",
button_primary_background_fill="#10b981",
button_primary_text_color="#ffffff",
block_background_fill="#ffffff"
)
title = "City Segment"
description = (
    "Urban scene segmentation with the segformer-b2 model.<br>"
    "Upload an image and roads, buildings, vehicles, people, and other "
    "objects are color-coded by class."
)

demo = gr.Interface(
fn=run_inference,
inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Transparency (alpha)")
],
    outputs=gr.Plot(label="Result"),
examples=[
["city1.png", 0.5],
["city2.png", 0.5],
["city3.jpg", 0.5],
["city4.jpeg", 0.5],
["city5.jpg", 0.5]
],
    flagging_mode="never",
    cache_examples=False,
    theme=custom_theme,
    title=title,
    description=description
)
if __name__ == "__main__":
    demo.launch()