# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # // # // Licensed under the Apache License, Version 2.0 (the "License"); # // you may not use this file except in compliance with the License. # // You may obtain a copy of the License at # // # // http://www.apache.org/licenses/LICENSE-2.0 # // # // Unless required by applicable law or agreed to in writing, software # // distributed under the License is distributed on an "AS IS" BASIS, # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # // See the License for the specific language governing permissions and # // limitations under the License. import math import random from typing import Union import torch from PIL import Image from torchvision.transforms import functional as TVF from torchvision.transforms.functional import InterpolationMode class AreaResize: def __init__( self, max_area: float, downsample_only: bool = False, interpolation: InterpolationMode = InterpolationMode.BICUBIC, ): self.max_area = max_area self.downsample_only = downsample_only self.interpolation = interpolation def __call__(self, image: Union[torch.Tensor, Image.Image]): if isinstance(image, torch.Tensor): height, width = image.shape[-2:] elif isinstance(image, Image.Image): width, height = image.size else: raise NotImplementedError scale = math.sqrt(self.max_area / (height * width)) # keep original height and width for small pictures. scale = 1 if scale >= 1 and self.downsample_only else scale resized_height, resized_width = round(height * scale), round(width * scale) return TVF.resize( image, size=(resized_height, resized_width), interpolation=self.interpolation, ) class AreaRandomCrop: def __init__( self, max_area: float, ): self.max_area = max_area def get_params(self, input_size, output_size): """Get parameters for ``crop`` for a random crop. Args: img (PIL Image): Image to be cropped. output_size (tuple): Expected output size of the crop. Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. """ # w, h = _get_image_size(img) h, w = input_size th, tw = output_size if w <= tw and h <= th: return 0, 0, h, w i = random.randint(0, h - th) j = random.randint(0, w - tw) return i, j, th, tw def __call__(self, image: Union[torch.Tensor, Image.Image]): if isinstance(image, torch.Tensor): height, width = image.shape[-2:] elif isinstance(image, Image.Image): width, height = image.size else: raise NotImplementedError resized_height = math.sqrt(self.max_area / (width / height)) resized_width = (width / height) * resized_height # print('>>>>>>>>>>>>>>>>>>>>>') # print((height, width)) # print( (resized_height, resized_width)) resized_height, resized_width = round(resized_height), round(resized_width) i, j, h, w = self.get_params((height, width), (resized_height, resized_width)) image = TVF.crop(image, i, j, h, w) return image class ScaleResize: def __init__( self, scale: float, ): self.scale = scale def __call__(self, image: Union[torch.Tensor, Image.Image]): if isinstance(image, torch.Tensor): height, width = image.shape[-2:] interpolation_mode = InterpolationMode.BILINEAR antialias = True if image.ndim == 4 else "warn" elif isinstance(image, Image.Image): width, height = image.size interpolation_mode = InterpolationMode.LANCZOS antialias = "warn" else: raise NotImplementedError scale = self.scale # keep original height and width for small pictures resized_height, resized_width = round(height * scale), round(width * scale) image = TVF.resize( image, size=(resized_height, resized_width), interpolation=interpolation_mode, antialias=antialias, ) return image