import json
import random

import spaces  # Hugging Face ZeroGPU support; import before torch on Spaces
import gradio as gr
import numpy as np
import onnxruntime
import torch
from PIL import Image, ImageColor
from torchvision.ops import box_convert
from torchvision.utils import draw_bounding_boxes

import rfdetr.datasets.transforms as T

def _box_yxyx_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
"""Convert bounding boxes from (y1, x1, y2, x2) format to (x1, y1, x2, y2) format.
Args:
boxes (torch.Tensor): A tensor of bounding boxes in the (y1, x1, y2, x2) format.
Returns:
torch.Tensor: A tensor of bounding boxes in the (x1, y1, x2, y2) format.
"""
y1, x1, y2, x2 = boxes.unbind(-1)
boxes = torch.stack((x1, y1, x2, y2), dim=-1)
return boxes

def _box_xyxy_to_yxyx(boxes: torch.Tensor) -> torch.Tensor:
"""Convert bounding boxes from (x1, y1, x2, y2) format to (y1, x1, y2, x2) format.
Args:
boxes (torch.Tensor): A tensor of bounding boxes in the (x1, y1, x2, y2) format.
Returns:
torch.Tensor: A tensor of bounding boxes in the (y1, x1, y2, x2) format.
"""
x1, y1, x2, y2 = boxes.unbind(-1)
boxes = torch.stack((y1, x1, y2, x2), dim=-1)
return boxes

# Adapted from: https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py#L168
def extended_box_convert(
boxes: torch.Tensor, in_fmt: str, out_fmt: str
) -> torch.Tensor:
"""
Converts boxes from given in_fmt to out_fmt.
Supported in_fmt and out_fmt are:
- 'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2 being bottom right. This is the format that torchvision utilities expect.
        - 'xywh': boxes are represented via corner, width and height, x1, y1 being top left and w, h being width and height.
        - 'cxcywh': boxes are represented via centre, width and height, cx, cy being the centre of the box and w, h being width and height.
- 'yxyx': boxes are represented via corners, y1, x1 being top left and y2, x2 being bottom right. This is the format that `amrcnn` model outputs.
Args:
boxes (Tensor[N, 4]): boxes which will be converted.
in_fmt (str): Input format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh', 'yxyx'].
out_fmt (str): Output format of given boxes. Supported formats are ['xyxy', 'xywh', 'cxcywh', 'yxyx'].
Returns:
Tensor[N, 4]: Boxes into converted format.
"""
if in_fmt == "yxyx":
# Convert to xyxy and assign in_fmt accordingly
boxes = _box_yxyx_to_xyxy(boxes)
in_fmt = "xyxy"
if out_fmt == "yxyx":
# Convert to xyxy if not already in that format
if in_fmt != "xyxy":
boxes = box_convert(boxes, in_fmt=in_fmt, out_fmt="xyxy")
# Convert to yxyx
boxes = _box_xyxy_to_yxyx(boxes)
else:
# Use torchvision's box_convert for other conversions
boxes = box_convert(boxes, in_fmt=in_fmt, out_fmt=out_fmt)
return boxes
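
# Illustrative usage (not part of the app flow): convert an `amrcnn`-style
# yxyx box into the xyxy layout expected by torchvision utilities.
#   >>> b = torch.tensor([[10.0, 20.0, 110.0, 220.0]])  # (y1, x1, y2, x2)
#   >>> extended_box_convert(b, in_fmt="yxyx", out_fmt="xyxy")
#   tensor([[ 20.,  10., 220., 110.]])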

def process_categories() -> tuple[dict, dict]:
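    # `categories.json` is assumed to hold a COCO-style list of category
    # records, e.g. [{"id": 1, "name": "jacket"}, ...] (names illustrative).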
with open("categories.json") as fp:
categories = json.load(fp)
category_id_to_name = {d["id"]: d["name"] for d in categories}
random.seed(42)
color_names = list(ImageColor.colormap.keys())
sampled_colors = random.sample(color_names, len(categories))
rgb_colors = [ImageColor.getrgb(color_name) for color_name in sampled_colors]
category_id_to_color = {category["id"]: color for category, color in zip(categories, rgb_colors)}
return category_id_to_name, category_id_to_color

def draw_predictions(boxes, labels, scores, img, score_threshold=0.5, font_size=30):
imgs_list = []
label_id_to_name, label_id_to_color = process_categories()
mask = scores > score_threshold
boxes_filtered = boxes[mask]
labels_filtered = labels[mask]
scores_filtered = scores[mask]
label_names = [label_id_to_name[int(i)] for i in labels_filtered]
colors = [label_id_to_color[int(i)] for i in labels_filtered]
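    # draw_bounding_boxes expects a uint8 image tensor of shape (C, H, W).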
img_bbox = draw_bounding_boxes(
img,
boxes=boxes_filtered,
labels=[f"{name}: {score:.2f}" for name, score in zip(label_names, scores_filtered)],
colors=colors,
width=10,
        font_size=font_size,
font="Kilikia.ttf",
)
imgs_list.append(img_bbox.permute(1, 2, 0).numpy()) # convert to HWC for Gradio
return imgs_list

def inference(image_path, model_name, bbox_threshold):
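    # The exported ONNX graphs take a fixed square input; 1120x1120 is assumed
    # to match the resolution the models were exported with.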
transforms = T.Compose([
T.SquareResize([1120]),
T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet mean/std
])
image = Image.open(image_path).convert("RGB")
tensor_img, _ = transforms(image, None)
tensor_img = tensor_img.unsqueeze(0)
    if model_name == "RF-DETR-B":
        model_path = "rfdetr.onnx"
    elif model_name == "RF-DETR-L":
        model_path = "rfdetrl.onnx"
    else:
        raise gr.Error(f"Unknown model: {model_name}")
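    # Graph optimizations are disabled, presumably to keep ONNX Runtime from
    # rewriting the exported RF-DETR graph.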
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
ort_session = onnxruntime.InferenceSession(
model_path,
providers=["CPUExecutionProvider"],
sess_options=sess_options
)
ort_inputs = {ort_session.get_inputs()[0].name: tensor_img.cpu().numpy()}
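    # Assumed ONNX output heads: 'dets' holds normalized (cx, cy, w, h) boxes
    # and 'labels' holds per-class logits (a sigmoid is applied below).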
pred_boxes, logits = ort_session.run(['dets', 'labels'], ort_inputs)
scores = torch.sigmoid(torch.from_numpy(logits))
max_scores, pred_labels = scores.max(-1)
mask = max_scores > bbox_threshold
pred_boxes = torch.from_numpy(pred_boxes[0])
image_w, image_h = image.size
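    # Scale the normalized (cx, cy, w, h) boxes to absolute pixel coordinates.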
pred_boxes_abs = pred_boxes.clone()
pred_boxes_abs[:, 0] *= image_w
pred_boxes_abs[:, 1] *= image_h
pred_boxes_abs[:, 2] *= image_w
pred_boxes_abs[:, 3] *= image_h
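    # Drop the batch dimension from the score mask before indexing.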
mask = mask.squeeze(0)
filtered_boxes = extended_box_convert(
pred_boxes_abs[mask], in_fmt="cxcywh", out_fmt="xyxy"
)
filtered_scores = max_scores.squeeze(0)[mask]
filtered_labels = pred_labels.squeeze(0)[mask]
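    # Convert the PIL image to a uint8 (C, H, W) tensor for drawing.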
img_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1)
print("drawing")
return draw_predictions(filtered_boxes, filtered_labels, filtered_scores, img_tensor, score_threshold=bbox_threshold)

title = "FashionUnveil - Demo"
description = r"""This is a demo of the research project <a href="https://github.com/DatSplit/FashionVeil">FashionUnveil</a>. Upload an image to run inference."""
demo = gr.Interface(
fn=inference,
inputs=[
gr.Image(type="filepath", label="Input Image"),
gr.Dropdown(["RF-DETR-L", "RF-DETR-B"], value="RF-DETR-B", label="Model"),
gr.Slider(value=0.5, minimum=0.0, maximum=0.9, step=0.05, label="BBox threshold"),
],
outputs=gr.Gallery(label="Output", preview=True, height=500),
title=title,
description=description,
)

if __name__ == "__main__":
demo.launch()