import math
import pathlib
import warnings
from types import FunctionType
from typing import Any, BinaryIO, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image, ImageColor, ImageDraw, ImageFont

__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]


def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        range (tuple, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use
                ``value_range`` instead.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        grid (Tensor): the tensor containing grid of images.
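
    Example (a minimal, illustrative sketch; the output shape follows from the
    padding and layout rules above)::

        >>> imgs = torch.rand(8, 3, 32, 32)  # mini-batch of 8 random RGB images
        >>> grid = make_grid(imgs, nrow=4)   # 2 rows x 4 columns, 2-pixel padding
        >>> grid.shape                       # (C, 2 * (32 + 2) + 2, 4 * (32 + 2) + 2)
        torch.Size([3, 70, 138])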
| """ | |
| if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | |
| _log_api_usage_once(make_grid) | |
| if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): | |
| raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") | |
| if "range" in kwargs.keys(): | |
| warnings.warn( | |
| "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. " | |
| "Please use 'value_range' instead." | |
| ) | |
| value_range = kwargs["range"] | |
| # if list of tensors, convert to a 4D mini-batch Tensor | |
| if isinstance(tensor, list): | |
| tensor = torch.stack(tensor, dim=0) | |
| if tensor.dim() == 2: # single image H x W | |
| tensor = tensor.unsqueeze(0) | |
| if tensor.dim() == 3: # single image | |
| if tensor.size(0) == 1: # if single-channel, convert to 3-channel | |
| tensor = torch.cat((tensor, tensor, tensor), 0) | |
| tensor = tensor.unsqueeze(0) | |
| if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images | |
| tensor = torch.cat((tensor, tensor, tensor), 1) | |
| if normalize is True: | |
| tensor = tensor.clone() # avoid modifying tensor in-place | |
| if value_range is not None: | |
| assert isinstance( | |
| value_range, tuple | |
| ), "value_range has to be a tuple (min, max) if specified. min and max are numbers" | |
| def norm_ip(img, low, high): | |
| img.clamp_(min=low, max=high) | |
| img.sub_(low).div_(max(high - low, 1e-5)) | |
| def norm_range(t, value_range): | |
| if value_range is not None: | |
| norm_ip(t, value_range[0], value_range[1]) | |
| else: | |
| norm_ip(t, float(t.min()), float(t.max())) | |
| if scale_each is True: | |
| for t in tensor: # loop over mini-batch dimension | |
| norm_range(t, value_range) | |
| else: | |
| norm_range(tensor, value_range) | |
| assert isinstance(tensor, torch.Tensor) | |
| if tensor.size(0) == 1: | |
| return tensor.squeeze(0) | |
| # make the mini-batch of images into a grid | |
| nmaps = tensor.size(0) | |
| xmaps = min(nrow, nmaps) | |
| ymaps = int(math.ceil(float(nmaps) / xmaps)) | |
| height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) | |
| num_channels = tensor.size(1) | |
| grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) | |
| k = 0 | |
| for y in range(ymaps): | |
| for x in range(xmaps): | |
| if k >= nmaps: | |
| break | |
| # Tensor.copy_() is a valid method but seems to be missing from the stubs | |
| # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ | |
| grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] | |
| 2, x * width + padding, width - padding | |
| ).copy_(tensor[k]) | |
| k = k + 1 | |
| return grid | |


def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object.
        format (Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
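
    Example (an illustrative sketch; ``"sample_grid.png"`` is a hypothetical output path)::

        >>> imgs = torch.rand(4, 3, 64, 64)              # values in [0, 1]
        >>> save_image(imgs, "sample_grid.png", nrow=2)  # writes a 2 x 2 grid to disk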
| """ | |
| if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | |
| _log_api_usage_once(save_image) | |
| grid = make_grid(tensor, **kwargs) | |
| # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer | |
| ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() | |
| im = Image.fromarray(ndarr) | |
| im.save(fp, format=format) | |


def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10,
) -> torch.Tensor:
| """ | |
| Draws bounding boxes on given image. | |
| The values of the input image should be uint8 between 0 and 255. | |
| If fill is True, Resulting Tensor should be saved as PNG image. | |
| Args: | |
| image (Tensor): Tensor of shape (C x H x W) and dtype uint8. | |
| boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that | |
| the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and | |
| `0 <= ymin < ymax < H`. | |
| labels (List[str]): List containing the labels of bounding boxes. | |
| colors (color or list of colors, optional): List containing the colors | |
| of the boxes or single color for all boxes. The color can be represented as | |
| PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. | |
| By default, random colors are generated for boxes. | |
| fill (bool): If `True` fills the bounding box with specified color. | |
| width (int): Width of bounding box. | |
| font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may | |
| also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, | |
| `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. | |
| font_size (int): The requested font size in points. | |
| Returns: | |
| img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. | |
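
    Example (a minimal, illustrative sketch on a synthetic canvas)::

        >>> img = torch.zeros(3, 100, 100, dtype=torch.uint8)  # black RGB canvas
        >>> boxes = torch.tensor([[10, 10, 50, 50], [55, 55, 95, 95]], dtype=torch.float)
        >>> out = draw_bounding_boxes(img, boxes, labels=["cat", "dog"], colors=["red", "blue"], width=2)
        >>> out.shape
        torch.Size([3, 100, 100])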
| """ | |
| if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | |
| _log_api_usage_once(draw_bounding_boxes) | |
| if not isinstance(image, torch.Tensor): | |
| raise TypeError(f"Tensor expected, got {type(image)}") | |
| elif image.dtype != torch.uint8: | |
| raise ValueError(f"Tensor uint8 expected, got {image.dtype}") | |
| elif image.dim() != 3: | |
| raise ValueError("Pass individual images, not batches") | |
| elif image.size(0) not in {1, 3}: | |
| raise ValueError("Only grayscale and RGB images are supported") | |
| num_boxes = boxes.shape[0] | |
| if labels is None: | |
| labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] | |
| elif len(labels) != num_boxes: | |
| raise ValueError( | |
| f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." | |
| ) | |
| if colors is None: | |
| colors = _generate_color_palette(num_boxes) | |
| elif isinstance(colors, list): | |
| if len(colors) < num_boxes: | |
| raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ") | |
| else: # colors specifies a single color for all boxes | |
| colors = [colors] * num_boxes | |
| colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] | |
| # Handle Grayscale images | |
| if image.size(0) == 1: | |
| image = torch.tile(image, (3, 1, 1)) | |
| ndarr = image.permute(1, 2, 0).cpu().numpy() | |
| img_to_draw = Image.fromarray(ndarr) | |
| img_boxes = boxes.to(torch.int64).tolist() | |
| if fill: | |
| draw = ImageDraw.Draw(img_to_draw, "RGBA") | |
| else: | |
| draw = ImageDraw.Draw(img_to_draw) | |
| txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size) | |
| for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] | |
| if fill: | |
| fill_color = color + (100,) | |
| draw.rectangle(bbox, width=width, outline=color, fill=fill_color) | |
| else: | |
| draw.rectangle(bbox, width=width, outline=color) | |
| if label is not None: | |
| margin = width + 1 | |
| draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) | |
| return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) | |


def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:
| """ | |
| Draws segmentation masks on given RGB image. | |
| The values of the input image should be uint8 between 0 and 255. | |
| Args: | |
| image (Tensor): Tensor of shape (3, H, W) and dtype uint8. | |
| masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. | |
| alpha (float): Float number between 0 and 1 denoting the transparency of the masks. | |
| 0 means full transparency, 1 means no transparency. | |
| colors (color or list of colors, optional): List containing the colors | |
| of the masks or single color for all masks. The color can be represented as | |
| PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. | |
| By default, random colors are generated for each mask. | |
| Returns: | |
| img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. | |
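
    Example (a minimal, illustrative sketch with one synthetic mask)::

        >>> img = torch.zeros(3, 100, 100, dtype=torch.uint8)  # black RGB canvas
        >>> mask = torch.zeros(100, 100, dtype=torch.bool)
        >>> mask[20:60, 20:60] = True                          # a square region
        >>> out = draw_segmentation_masks(img, mask, alpha=0.5, colors="green")
        >>> out.shape, out.dtype
        (torch.Size([3, 100, 100]), torch.uint8)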
| """ | |
| if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | |
| _log_api_usage_once(draw_segmentation_masks) | |
| if not isinstance(image, torch.Tensor): | |
| raise TypeError(f"The image must be a tensor, got {type(image)}") | |
| elif image.dtype != torch.uint8: | |
| raise ValueError(f"The image dtype must be uint8, got {image.dtype}") | |
| elif image.dim() != 3: | |
| raise ValueError("Pass individual images, not batches") | |
| elif image.size()[0] != 3: | |
| raise ValueError("Pass an RGB image. Other Image formats are not supported") | |
| if masks.ndim == 2: | |
| masks = masks[None, :, :] | |
| if masks.ndim != 3: | |
| raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") | |
| if masks.dtype != torch.bool: | |
| raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") | |
| if masks.shape[-2:] != image.shape[-2:]: | |
| raise ValueError("The image and the masks must have the same height and width") | |
    num_masks = masks.size()[0]
    if colors is None:
        colors = _generate_color_palette(num_masks)
    if not isinstance(colors, list):  # a single color was given; reuse it for every mask
        colors = [colors] * num_masks
    if len(colors) < num_masks:
        raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")
    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        colors_.append(torch.tensor(color, dtype=out_dtype))

    img_to_draw = image.detach().clone()
    # TODO: There might be a way to vectorize this
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)


def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
) -> torch.Tensor:
| """ | |
| Draws Keypoints on given RGB image. | |
| The values of the input image should be uint8 between 0 and 255. | |
| Args: | |
| image (Tensor): Tensor of shape (3, H, W) and dtype uint8. | |
| keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, | |
| in the format [x, y]. | |
| connectivity (List[Tuple[int, int]]]): A List of tuple where, | |
| each tuple contains pair of keypoints to be connected. | |
| colors (str, Tuple): The color can be represented as | |
| PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. | |
| radius (int): Integer denoting radius of keypoint. | |
| width (int): Integer denoting width of line connecting keypoints. | |
| Returns: | |
| img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. | |
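
    Example (a minimal, illustrative sketch with one instance of three keypoints)::

        >>> img = torch.zeros(3, 100, 100, dtype=torch.uint8)
        >>> kpts = torch.tensor([[[20.0, 20.0], [50.0, 50.0], [80.0, 20.0]]])  # (1, K=3, 2)
        >>> out = draw_keypoints(img, kpts, connectivity=[(0, 1), (1, 2)], colors="yellow", radius=3)
        >>> out.shape
        torch.Size([3, 100, 100])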
| """ | |
| if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | |
| _log_api_usage_once(draw_keypoints) | |
| if not isinstance(image, torch.Tensor): | |
| raise TypeError(f"The image must be a tensor, got {type(image)}") | |
| elif image.dtype != torch.uint8: | |
| raise ValueError(f"The image dtype must be uint8, got {image.dtype}") | |
| elif image.dim() != 3: | |
| raise ValueError("Pass individual images, not batches") | |
| elif image.size()[0] != 3: | |
| raise ValueError("Pass an RGB image. Other Image formats are not supported") | |
| if keypoints.ndim != 3: | |
| raise ValueError("keypoints must be of shape (num_instances, K, 2)") | |
| ndarr = image.permute(1, 2, 0).cpu().numpy() | |
| img_to_draw = Image.fromarray(ndarr) | |
| draw = ImageDraw.Draw(img_to_draw) | |
| img_kpts = keypoints.to(torch.int64).tolist() | |
| for kpt_id, kpt_inst in enumerate(img_kpts): | |
| for inst_id, kpt in enumerate(kpt_inst): | |
| x1 = kpt[0] - radius | |
| x2 = kpt[0] + radius | |
| y1 = kpt[1] - radius | |
| y2 = kpt[1] + radius | |
| draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) | |
| if connectivity: | |
| for connection in connectivity: | |
| start_pt_x = kpt_inst[connection[0]][0] | |
| start_pt_y = kpt_inst[connection[0]][1] | |
| end_pt_x = kpt_inst[connection[1]][0] | |
| end_pt_y = kpt_inst[connection[1]][1] | |
| draw.line( | |
| ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), | |
| width=width, | |
| ) | |
| return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) | |


# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a flow to an RGB image.

    Args:
        flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.

    Returns:
        img (Tensor): Image Tensor of dtype uint8 where each color corresponds
            to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
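
    Example (a minimal, illustrative sketch with a random displacement field)::

        >>> flow = torch.randn(2, 32, 32)  # (2, H, W): per-pixel (dx, dy) displacements
        >>> img = flow_to_image(flow)
        >>> img.shape, img.dtype
        (torch.Size([3, 32, 32]), torch.uint8)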
| """ | |
| if flow.dtype != torch.float: | |
| raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") | |
| orig_shape = flow.shape | |
| if flow.ndim == 3: | |
| flow = flow[None] # Add batch dim | |
| if flow.ndim != 4 or flow.shape[1] != 2: | |
| raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") | |
| max_norm = torch.sum(flow ** 2, dim=1).sqrt().max() | |
| epsilon = torch.finfo((flow).dtype).eps | |
| normalized_flow = flow / (max_norm + epsilon) | |
| img = _normalized_flow_to_image(normalized_flow) | |
| if len(orig_shape) == 3: | |
| img = img[0] # Remove batch dim | |
| return img | |


def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Converts a batch of normalized flow to an RGB image.

    Args:
        normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
    Returns:
        img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
    """
    N, _, H, W = normalized_flow.shape
    device = normalized_flow.device
    flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
    colorwheel = _make_colorwheel().to(device)  # shape [55, 3]
    num_cols = colorwheel.shape[0]
    norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
    a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
    fk = (a + 1) / 2 * (num_cols - 1)
    k0 = torch.floor(fk).to(torch.long)
    k1 = k0 + 1
    k1[k1 == num_cols] = 0
    f = fk - k0

    for c in range(colorwheel.shape[1]):
        tmp = colorwheel[:, c]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        col = 1 - norm * (1 - col)
        flow_image[:, c, :, :] = torch.floor(255 * col)
    return flow_image


def _make_colorwheel() -> torch.Tensor:
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.

    Returns:
        colorwheel (Tensor[55, 3]): Colorwheel Tensor.
    """
    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = torch.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY)
    col = col + RY
    # YG
    colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG)
    colorwheel[col : col + YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col : col + GC, 1] = 255
    colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC)
    col = col + GC
    # CB
    colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(0, CB) / CB)
    colorwheel[col : col + CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col : col + BM, 2] = 255
    colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM)
    col = col + BM
    # MR
    colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(0, MR) / MR)
    colorwheel[col : col + MR, 0] = 255
    return colorwheel


def _generate_color_palette(num_objects: int):
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    return [tuple((i * palette) % 255) for i in range(num_objects)]


def _log_api_usage_once(obj: Any) -> None:
    """
    Logs API usage (module and name) within an organization.
    In a large ecosystem, it's often useful to track the PyTorch and
    TorchVision APIs usage. This API provides similar functionality to the
    logging module in the Python stdlib. It can be used for debugging purposes
    to log which methods are used and by default it is inactive, unless the user
    manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_.
    Please note it is triggered only once for the same API call within a process.
    It does not collect any data from open-source users since it is a no-op by default.
    For more information, please refer to
    * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging;
    * Logging policy: https://github.com/pytorch/vision/issues/5052;

    Args:
        obj (class instance or method): an object to extract info from.
    """
    if not obj.__module__.startswith("torchvision"):
        return
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{obj.__module__}.{name}")