import json
import random

import spaces
import gradio as gr
import numpy as np
import onnxruntime
import torch
from PIL import Image, ImageColor
from torchvision.utils import draw_bounding_boxes
from torchvision.ops import box_convert

import rfdetr.datasets.transforms as T


def _box_yxyx_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
    """Convert bounding boxes from (y1, x1, y2, x2) format to (x1, y1, x2, y2) format.

    Args:
        boxes (torch.Tensor): A tensor of bounding boxes in the (y1, x1, y2, x2) format.

    Returns:
        torch.Tensor: A tensor of bounding boxes in the (x1, y1, x2, y2) format.
    """
    y1, x1, y2, x2 = boxes.unbind(-1)
    boxes = torch.stack((x1, y1, x2, y2), dim=-1)
    return boxes


def _box_xyxy_to_yxyx(boxes: torch.Tensor) -> torch.Tensor:
    """Convert bounding boxes from (x1, y1, x2, y2) format to (y1, x1, y2, x2) format.

    Args:
        boxes (torch.Tensor): A tensor of bounding boxes in the (x1, y1, x2, y2) format.

    Returns:
        torch.Tensor: A tensor of bounding boxes in the (y1, x1, y2, x2) format.
    """
    x1, y1, x2, y2 = boxes.unbind(-1)
    boxes = torch.stack((y1, x1, y2, x2), dim=-1)
    return boxes


# Adapted from: https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py#L168
def extended_box_convert(
    boxes: torch.Tensor, in_fmt: str, out_fmt: str
) -> torch.Tensor:
    """Convert boxes from the given in_fmt to out_fmt.

    Supported in_fmt and out_fmt are:

    - 'xyxy': boxes are represented via corners, x1, y1 being top left and x2, y2
      being bottom right. This is the format that torchvision utilities expect.
    - 'xywh': boxes are represented via corner, width and height, x1, y1 being
      top left, w, h being width and height.
    - 'cxcywh': boxes are represented via centre, width and height, cx, cy being
      the center of the box, w, h being width and height.
    - 'yxyx': boxes are represented via corners, y1, x1 being top left and y2, x2
      being bottom right. This is the format that the `amrcnn` model outputs.

    Args:
        boxes (Tensor[N, 4]): Boxes which will be converted.
        in_fmt (str): Input format of given boxes. Supported formats are
            ['xyxy', 'xywh', 'cxcywh', 'yxyx'].
        out_fmt (str): Output format of given boxes. Supported formats are
            ['xyxy', 'xywh', 'cxcywh', 'yxyx'].

    Returns:
        Tensor[N, 4]: Boxes in the converted format.
    """
    if in_fmt == "yxyx":
        # Convert to xyxy and update in_fmt accordingly
        boxes = _box_yxyx_to_xyxy(boxes)
        in_fmt = "xyxy"

    if out_fmt == "yxyx":
        # Convert to xyxy first if not already in that format
        if in_fmt != "xyxy":
            boxes = box_convert(boxes, in_fmt=in_fmt, out_fmt="xyxy")
        # Then convert to yxyx
        boxes = _box_xyxy_to_yxyx(boxes)
    else:
        # Use torchvision's box_convert for all other conversions
        boxes = box_convert(boxes, in_fmt=in_fmt, out_fmt=out_fmt)
    return boxes
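
# Illustrative example (assumed values, not part of the app): a box centred at
# (200, 150) with width 100 and height 50 in "cxcywh" format converts to corner
# format as follows:
#     extended_box_convert(torch.tensor([[200.0, 150.0, 100.0, 50.0]]),
#                          in_fmt="cxcywh", out_fmt="xyxy")
#     # -> tensor([[150., 125., 250., 175.]])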


def process_categories() -> tuple:
    """Load categories.json and build id -> name and id -> RGB color lookups."""
    with open("categories.json") as fp:
        categories = json.load(fp)
    category_id_to_name = {d["id"]: d["name"] for d in categories}

    # Deterministically sample one distinct named color per category.
    random.seed(42)
    color_names = list(ImageColor.colormap.keys())
    sampled_colors = random.sample(color_names, len(categories))
    rgb_colors = [ImageColor.getrgb(color_name) for color_name in sampled_colors]
    category_id_to_color = {
        category["id"]: color for category, color in zip(categories, rgb_colors)
    }
    return category_id_to_name, category_id_to_color
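
# Assumed structure of categories.json, inferred from the lookups above (the
# real file may carry additional fields; names here are placeholders):
#     [
#         {"id": 1, "name": "<category name>"},
#         {"id": 2, "name": "<category name>"},
#         ...
#     ]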


def draw_predictions(boxes, labels, scores, img, score_threshold=0.5, font_size=30):
    """Draw labelled bounding boxes on `img` and return a list of HWC numpy images."""
    imgs_list = []
    label_id_to_name, label_id_to_color = process_categories()

    # Keep only predictions above the score threshold.
    mask = scores > score_threshold
    boxes_filtered = boxes[mask]
    labels_filtered = labels[mask]
    scores_filtered = scores[mask]

    label_names = [label_id_to_name[int(i)] for i in labels_filtered]
    colors = [label_id_to_color[int(i)] for i in labels_filtered]

    img_bbox = draw_bounding_boxes(
        img,
        boxes=boxes_filtered,
        labels=[f"{name}: {score:.2f}" for name, score in zip(label_names, scores_filtered)],
        colors=colors,
        width=10,
        font_size=font_size,
        font="Kilikia.ttf",
    )
    imgs_list.append(img_bbox.permute(1, 2, 0).numpy())  # convert to HWC for Gradio
    return imgs_list


def inference(image_path, model_name, bbox_threshold):
    transforms = T.Compose([
        T.SquareResize([1120]),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    tensor_img, _ = transforms(image, None)
    tensor_img = tensor_img.unsqueeze(0)

    print(model_name)
    if model_name == "RF-DETR-B":
        model_path = "rfdetr.onnx"
    elif model_name == "RF-DETR-L":
        model_path = "rfdetrl.onnx"
    else:
        raise gr.Error(f"Unknown model: {model_name}")

    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
    ort_session = onnxruntime.InferenceSession(
        model_path,
        providers=["CPUExecutionProvider"],
        sess_options=sess_options,
    )
    ort_inputs = {ort_session.get_inputs()[0].name: tensor_img.cpu().numpy()}
    pred_boxes, logits = ort_session.run(["dets", "labels"], ort_inputs)
    print(pred_boxes)

    # Per-query class scores; keep the best class per query.
    scores = torch.sigmoid(torch.from_numpy(logits))
    max_scores, pred_labels = scores.max(-1)
    mask = max_scores > bbox_threshold

    # Boxes are normalised cxcywh; scale them to absolute pixel coordinates.
    pred_boxes = torch.from_numpy(pred_boxes[0])
    image_w, image_h = image.size
    pred_boxes_abs = pred_boxes.clone()
    pred_boxes_abs[:, 0] *= image_w
    pred_boxes_abs[:, 1] *= image_h
    pred_boxes_abs[:, 2] *= image_w
    pred_boxes_abs[:, 3] *= image_h

    mask = mask.squeeze(0)
    filtered_boxes = extended_box_convert(
        pred_boxes_abs[mask], in_fmt="cxcywh", out_fmt="xyxy"
    )
    filtered_scores = max_scores.squeeze(0)[mask]
    filtered_labels = pred_labels.squeeze(0)[mask]

    img_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1)
    print("drawing")
    return draw_predictions(
        filtered_boxes, filtered_labels, filtered_scores, img_tensor,
        score_threshold=bbox_threshold,
    )
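
# Illustrative standalone call (assumes an input image and the ONNX weights sit
# next to this script; "example.jpg" is a placeholder file name):
#     gallery = inference("example.jpg", model_name="RF-DETR-B", bbox_threshold=0.5)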


title = "FashionUnveil - Demo"
description = r"""This is a demo of the research project <a href="https://github.com/DatSplit/FashionVeil">FashionUnveil</a>. Upload an image to run inference."""

demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.Dropdown(["RF-DETR-L", "RF-DETR-B"], value="RF-DETR-B", label="Model"),
        gr.Slider(value=0.5, minimum=0.0, maximum=0.9, step=0.05, label="BBox threshold"),
    ],
    outputs=gr.Gallery(label="Output", preview=True, height=500),
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()