Update app.py

app.py CHANGED
@@ -1,3 +1,7 @@
+## Some code was modified from Ovseg and OV-Sam. Thanks to their excellent work.
+## Ovseg code: https://github.com/facebookresearch/ov-seg
+## OV-Sam code: https://github.com/HarborYuan/ovsam
+
 import spaces
 import multiprocessing as mp
 import numpy as np
@@ -18,6 +22,17 @@ import gradio as gr
 import open_clip
 from sam2.build_sam import build_sam2
 from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter
+from mask_adapter.data.datasets import openseg_classes
+
+COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
+thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
+stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
+ADE20K_150_CATEGORIES_ = openseg_classes.get_ade20k_categories_with_prompt_eng()
+ade20k_thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES_ if k["isthing"] == 1]
+ade20k_stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES_]
+class_names_coco_ade20k = thing_classes + stuff_classes + ade20k_thing_classes + ade20k_stuff_classes
+
+



@@ -93,7 +108,7 @@ def inference_automatic(input_img, class_names):
 @spaces.GPU
 @torch.no_grad()
 @torch.autocast(device_type="cuda", dtype=torch.float32)
-def inference_point(input_img, img_state,):
+def inference_point(input_img, img_state, class_names_input):


     mp.set_start_method("spawn", force=True)
@@ -106,8 +121,20 @@ def inference_point(input_img, img_state,):

     demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)

-
-
+    if not class_names_input:
+        class_names_input = class_names_coco_ade20k
+
+    if class_names_input == class_names_coco_ade20k:
+        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding.npy")).cuda()
+        _, visualized_output = demo.run_on_image_with_points(img_state.img, points, text_features)
+    else:
+        class_names_input = class_names_input.split(',')
+        txts = [f'a photo of {cls_name}' for cls_name in class_names_input]
+        text = open_clip.tokenize(txts)
+        text_features = clip_model.encode_text(text.cuda())
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+        _, visualized_output = demo.run_on_image_with_points(img_state.img, points, text_features, class_names_input)
+
     return visualized_output


@@ -136,8 +163,20 @@ def inference_box(input_img, img_state,):

     demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)

-
-
+    if not class_names_input:
+        class_names_input = class_names_coco_ade20k
+
+    if class_names_input == class_names_coco_ade20k:
+        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding.npy")).cuda()
+        _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox, text_features)
+    else:
+        class_names_input = class_names_input.split(',')
+        txts = [f'a photo of {cls_name}' for cls_name in class_names_input]
+        text = open_clip.tokenize(txts)
+        text_features = clip_model.encode_text(text.cuda())
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+        _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox, text_features, class_names_input)
+
     return visualized_output


@@ -234,7 +273,7 @@ def preprocess_example(input_img, img_state):

 def clear_everything(img_state):
     img_state.clear()
-    return img_state, None, None
+    return img_state, None, None, gr.Textbox(value='', lines=1, placeholder=class_names_coco_ade20k, label='Class Names')


 def clean_prompts(img_state):
@@ -296,7 +335,7 @@ with gr.Blocks() as demo:
             output_image = gr.Image(type="pil", label='Segmentation Map')

             # Buttons below segmentation map (now placed under segmentation map)
-            run_button = gr.Button("Run Automatic Segmentation")
+            run_button = gr.Button("Run Automatic Segmentation", elem_id="run_button", variant='primary')
             run_button.click(inference_automatic, inputs=[input_image, class_names], outputs=output_image)

             clear_button = gr.Button("Clear")
@@ -310,9 +349,12 @@ with gr.Blocks() as demo:
         with gr.Row():  # horizontal layout
             with gr.Column(scale=1):
                 input_image = gr.Image(label="Input Image", type="pil")
-
+                class_names_input_box = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
+            with gr.Column(scale=1):
                 output_image_box = gr.Image(type="pil", label='Segmentation Map', interactive=False)  # output segmentation map
-
+                clear_prompt_button_box = gr.Button("Clean Prompt")
+                clear_button_box = gr.Button("Restart")
+
         gr.Markdown("Click the top-left and bottom-right corners of the image to select a rectangular area")

         input_image.select(
@@ -321,30 +363,31 @@ with gr.Blocks() as demo:
             outputs=[img_state_bbox, input_image]
         ).then(
             inference_box,
-            inputs=[input_image, img_state_bbox],
+            inputs=[input_image, img_state_bbox, class_names_input_box],
             outputs=[output_image_box]
         )
-
+
+
         clear_prompt_button_box.click(
             clean_prompts,
             inputs=[img_state_bbox],
             outputs=[img_state_bbox, input_image, output_image_box]
         )
-
+
         clear_button_box.click(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box, class_names_input_box]
         )
         input_image.clear(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box, class_names_input_box]
         )
         output_image_box.clear(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box, class_names_input_box]
         )


@@ -363,44 +406,41 @@ with gr.Blocks() as demo:
         with gr.Row():  # horizontal layout
             with gr.Column(scale=1):
                 input_image = gr.Image(label="Input Image", type="pil")
-
+                class_names_input_point = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
+            with gr.Column(scale=1):
                 output_image_point = gr.Image(type="pil", label='Segmentation Map', interactive=False)  # output segmentation map
-
+                clear_prompt_button_point = gr.Button("Clean Prompt")
+                clear_button_point = gr.Button("Restart")
+
         input_image.select(
             get_points_with_draw,
             [input_image, img_state_points],
             outputs=[img_state_points, input_image]
         ).then(
             inference_point,
-            inputs=[input_image, img_state_points],
+            inputs=[input_image, img_state_points, class_names_input_point],
             outputs=[output_image_point]
         )
-        clear_prompt_button_point = gr.Button("Clean Prompt")
         clear_prompt_button_point.click(
             clean_prompts,
             inputs=[img_state_points],
             outputs=[img_state_points, input_image, output_image_point]
         )
-        clear_button_point = gr.Button("Restart")
         clear_button_point.click(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point, class_names_input_point]
        )
         input_image.clear(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point, class_names_input_point]
         )
         output_image_point.clear(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point, class_names_input_point]
         )
-        def clear_and_set_example_point(example):
-            clear_everything(img_state_points)
-            return example
-
         gr.Examples(
             examples=examples_point,
             inputs=[input_image, img_state_points],
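
Note: the default branch above loads a precomputed cache, ./text_embedding/coco_ade20k_text_embedding.npy, which this commit does not show being generated. Below is a minimal offline sketch of how such a cache could be produced with open_clip, mirroring the on-the-fly branch of inference_point/inference_box. The ConvNeXt-Large checkpoint name is an assumption (the diff never shows which CLIP model the app builds), and the cache is assumed to hold L2-normalized features, matching the normalization applied in the custom-vocabulary path.

# generate_text_embedding.py -- hypothetical helper, not part of this commit
import numpy as np
import torch
import open_clip
from mask_adapter.data.datasets import openseg_classes

# Rebuild the same COCO + ADE20K vocabulary that app.py assembles at module level.
coco = openseg_classes.get_coco_categories_with_prompt_eng()
ade = openseg_classes.get_ade20k_categories_with_prompt_eng()
class_names = (
    [k["name"] for k in coco if k["isthing"] == 1]
    + [k["name"] for k in coco]
    + [k["name"] for k in ade if k["isthing"] == 1]
    + [k["name"] for k in ade]
)

# Assumed CLIP checkpoint; substitute whatever model app.py actually loads.
model, _, _ = open_clip.create_model_and_transforms(
    "convnext_large_d_320", pretrained="laion2b_s29b_b131k_ft_soup"
)
model.eval()

with torch.no_grad():
    tokens = open_clip.tokenize([f"a photo of {name}" for name in class_names])
    feats = model.encode_text(tokens)
    feats = feats / feats.norm(dim=-1, keepdim=True)  # L2-normalize, as in the else branch

np.save("./text_embedding/coco_ade20k_text_embedding.npy", feats.cpu().numpy())

Caching the default vocabulary this way means a point or box query only re-encodes text when the user types a custom class list; the default COCO + ADE20K prompt set is loaded from disk instead of being tokenized and encoded on every interaction.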