import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional

import torch
import torch.nn as nn

# from .._internally_replaced_utils import load_state_dict_from_url
from .vision_transformer_misc import ConvNormActivation
from .vision_transformer_utils import _log_api_usage_once

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# __all__ = [
#     "VisionTransformer",
#     "vit_b_16",
#     "vit_b_32",
#     "vit_l_16",
#     "vit_l_32",
# ]

model_urls = {
    "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
    "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
    "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
    "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}

class ConvStemConfig(NamedTuple):
    out_channels: int
    kernel_size: int
    stride: int
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    activation_layer: Callable[..., nn.Module] = nn.ReLU
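
# Illustrative sketch (not part of the original file): a convolutional stem in the spirit of
# https://arxiv.org/abs/2106.14881, showing how ConvStemConfig entries are consumed via
# VisionTransformer(conv_stem_configs=...). The channel counts below are assumptions; the one
# real constraint is that the product of the strides equals patch_size, because
# VisionTransformer._process_input reshapes the stem output to an
# (image_size // patch_size) x (image_size // patch_size) grid. Four stride-2 layers match
# patch_size=16.
_example_conv_stem_configs = [
    ConvStemConfig(out_channels=64, kernel_size=3, stride=2),
    ConvStemConfig(out_channels=128, kernel_size=3, stride=2),
    ConvStemConfig(out_channels=256, kernel_size=3, stride=2),
    ConvStemConfig(out_channels=512, kernel_size=3, stride=2),
]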

class MLPBlock(nn.Sequential):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)
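
# Shape-check sketch (illustrative, not part of the original module; dimensions are assumed):
# MLPBlock expands each token from in_dim to mlp_dim and projects back, so the output shape
# matches the input shape. Nothing runs at import time; call it yourself to verify.
def _demo_mlp_block() -> None:
    block = MLPBlock(in_dim=768, mlp_dim=3072, dropout=0.0)
    tokens = torch.randn(2, 197, 768)  # (batch_size, seq_length, hidden_dim)
    assert block(tokens).shape == tokens.shape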

class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y
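
# Sketch of the pre-norm residual structure above (illustrative, not part of the original
# module; dimensions are assumed): both the attention branch and the MLP branch are added back
# to their inputs, so an EncoderBlock preserves the (batch_size, seq_length, hidden_dim) shape.
def _demo_encoder_block() -> None:
    block = EncoderBlock(num_heads=12, hidden_dim=768, mlp_dim=3072, dropout=0.0, attention_dropout=0.0)
    tokens = torch.randn(2, 197, 768)
    assert block(tokens).shape == tokens.shape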

class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Note that batch_size is on the first dim because
        # we use batch_first=True in nn.MultiheadAttention()
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))
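
# Illustrative sketch (assumed dimensions, not part of the original module): the Encoder adds a
# learned positional embedding, runs the stack of EncoderBlocks, and applies a final LayerNorm,
# so seq_length passed at construction must match the sequence length of the input tensor.
def _demo_encoder() -> None:
    enc = Encoder(seq_length=197, num_layers=2, num_heads=12, hidden_dim=768,
                  mlp_dim=3072, dropout=0.0, attention_dropout=0.0)
    tokens = torch.randn(2, 197, 768)
    assert enc(tokens).shape == tokens.shape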

class VisionTransformer(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929."""

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
    ):
        super().__init__()
        _log_api_usage_once(self)
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        if conv_stem_configs is not None:
            # As per https://arxiv.org/abs/2106.14881
            seq_proj = nn.Sequential()
            prev_channels = 3
            for i, conv_stem_layer_config in enumerate(conv_stem_configs):
                seq_proj.add_module(
                    f"conv_bn_relu_{i}",
                    ConvNormActivation(
                        in_channels=prev_channels,
                        out_channels=conv_stem_layer_config.out_channels,
                        kernel_size=conv_stem_layer_config.kernel_size,
                        stride=conv_stem_layer_config.stride,
                        norm_layer=conv_stem_layer_config.norm_layer,
                        activation_layer=conv_stem_layer_config.activation_layer,
                    ),
                )
                prev_channels = conv_stem_layer_config.out_channels
            seq_proj.add_module(
                "conv_last", nn.Conv2d(in_channels=prev_channels, out_channels=hidden_dim, kernel_size=1)
            )
            self.conv_proj: nn.Module = seq_proj
        else:
            self.conv_proj = nn.Conv2d(
                in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
            )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)
        self.heads = nn.Sequential(heads_layers)

        if isinstance(self.conv_proj, nn.Conv2d):
            # Init the patchify stem
            fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
            nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
            if self.conv_proj.bias is not None:
                nn.init.zeros_(self.conv_proj.bias)
        elif self.conv_proj.conv_last is not None and isinstance(self.conv_proj.conv_last, nn.Conv2d):
            # Init the last 1x1 conv of the conv stem
            nn.init.normal_(
                self.conv_proj.conv_last.weight, mean=0.0, std=math.sqrt(2.0 / self.conv_proj.conv_last.out_channels)
            )
            if self.conv_proj.conv_last.bias is not None:
                nn.init.zeros_(self.conv_proj.conv_last.bias)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, "Wrong image height!")
        torch._assert(w == self.image_size, "Wrong image width!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where N is the batch size, S is the source sequence length, and
        # E is the embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor):
        out = {}

        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Patch tokens (everything except the class token), reshaped into a 2D feature map
        img_feature = x[:, 1:]
        H = W = self.image_size // self.patch_size
        out["f4"] = img_feature.view(n, H, W, self.hidden_dim).permute(0, 3, 1, 2)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        out["penultimate"] = x

        x = self.heads(x)  # for the pretrained torchvision ViTs, heads is a single fully-connected layer
        out["logits"] = x

        return out
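
# Sketch of the returned dictionary (shapes are assumptions for a ViT-B/16-style configuration
# at 224x224 input; this helper is not part of the original module): "f4" is a
# (N, hidden_dim, H, W) patch-token feature map, "penultimate" is the class-token embedding,
# and "logits" is the classifier output. num_layers=2 keeps the check cheap.
def _demo_vit_outputs() -> None:
    model = VisionTransformer(image_size=224, patch_size=16, num_layers=2, num_heads=12,
                              hidden_dim=768, mlp_dim=3072)
    out = model(torch.randn(1, 3, 224, 224))
    assert out["f4"].shape == (1, 768, 14, 14)
    assert out["penultimate"].shape == (1, 768)
    assert out["logits"].shape == (1, 1000)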

def _vision_transformer(
    arch: str,
    patch_size: int,
    num_layers: int,
    num_heads: int,
    hidden_dim: int,
    mlp_dim: int,
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> VisionTransformer:
    image_size = kwargs.pop("image_size", 224)
    model = VisionTransformer(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_dim=hidden_dim,
        mlp_dim=mlp_dim,
        **kwargs,
    )
    if pretrained:
        if arch not in model_urls:
            raise ValueError(f"No checkpoint is available for model type '{arch}'!")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model

def vit_b_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_b_16",
        patch_size=16,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )
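
# Usage sketch (illustrative, not part of the original module): build the ViT-B/16 variant and
# run a dummy batch through it. pretrained=True downloads the checkpoint listed in model_urls
# and requires network access; the checkpoint expects 224x224 inputs, so a different
# image_size would additionally need interpolate_embeddings (defined below).
def _demo_vit_b_16() -> None:
    model = vit_b_16(pretrained=False)  # set pretrained=True to load the released weights
    logits = model(torch.randn(1, 3, 224, 224))["logits"]
    assert logits.shape == (1, 1000)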

def vit_b_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_b_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_b_32",
        patch_size=32,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        mlp_dim=3072,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def vit_l_16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_16 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_l_16",
        patch_size=16,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )


def vit_l_32(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VisionTransformer:
    """
    Constructs a vit_l_32 architecture from
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" <https://arxiv.org/abs/2010.11929>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vision_transformer(
        arch="vit_l_32",
        patch_size=32,
        num_layers=24,
        num_heads=16,
        hidden_dim=1024,
        mlp_dim=4096,
        pretrained=pretrained,
        progress=progress,
        **kwargs,
    )

def interpolate_embeddings(
    image_size: int,
    patch_size: int,
    model_state: "OrderedDict[str, torch.Tensor]",
    interpolation_mode: str = "bicubic",
    reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]":
    """Interpolates the positional embeddings during checkpoint loading, which is useful when
    applying a pre-trained model to images at a different resolution.

    Args:
        image_size (int): Image size of the new model.
        patch_size (int): Patch size of the new model.
        model_state (OrderedDict[str, torch.Tensor]): State dict of the pre-trained model.
        interpolation_mode (str): The algorithm used for upsampling. Default: bicubic.
        reset_heads (bool): If True, the state of the classification heads is not copied. Default: False.

    Returns:
        OrderedDict[str, torch.Tensor]: A state dict which can be loaded into the new model.
    """
    # Shape of pos_embedding is (1, seq_length, hidden_dim)
    pos_embedding = model_state["encoder.pos_embedding"]
    n, seq_length, hidden_dim = pos_embedding.shape
    if n != 1:
        raise ValueError(f"Unexpected position embedding shape: {pos_embedding.shape}")

    new_seq_length = (image_size // patch_size) ** 2 + 1

    # Need to interpolate the weights for the position embedding.
    # We do this by reshaping the position embeddings to a 2d grid, performing
    # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
    if new_seq_length != seq_length:
        # The class token embedding shouldn't be interpolated so we split it up.
        seq_length -= 1
        new_seq_length -= 1
        pos_embedding_token = pos_embedding[:, :1, :]
        pos_embedding_img = pos_embedding[:, 1:, :]

        # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
        pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
        seq_length_1d = int(math.sqrt(seq_length))
        torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")

        # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
        pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
        new_seq_length_1d = image_size // patch_size

        # Perform interpolation.
        # (1, hidden_dim, seq_l_1d, seq_l_1d) -> (1, hidden_dim, new_seq_l_1d, new_seq_l_1d)
        new_pos_embedding_img = nn.functional.interpolate(
            pos_embedding_img,
            size=new_seq_length_1d,
            mode=interpolation_mode,
            align_corners=True,
        )

        # (1, hidden_dim, new_seq_l_1d, new_seq_l_1d) -> (1, hidden_dim, new_seq_length)
        new_pos_embedding_img = new_pos_embedding_img.reshape(1, hidden_dim, new_seq_length)

        # (1, hidden_dim, new_seq_length) -> (1, new_seq_length, hidden_dim)
        new_pos_embedding_img = new_pos_embedding_img.permute(0, 2, 1)
        new_pos_embedding = torch.cat([pos_embedding_token, new_pos_embedding_img], dim=1)
        model_state["encoder.pos_embedding"] = new_pos_embedding

        if reset_heads:
            model_state_copy: "OrderedDict[str, torch.Tensor]" = OrderedDict()
            for k, v in model_state.items():
                if not k.startswith("heads"):
                    model_state_copy[k] = v
            model_state = model_state_copy

    return model_state
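
# Usage sketch (illustrative, not part of the original module): adapt a state dict produced for
# 224x224 input to a model built for 384x384 input. The weights here are randomly initialized;
# for real use you would start from a pretrained checkpoint instead.
def _demo_interpolate_embeddings() -> None:
    src = vit_b_16(pretrained=False, image_size=224)
    new_state = interpolate_embeddings(
        image_size=384, patch_size=16, model_state=OrderedDict(src.state_dict())
    )
    dst = vit_b_16(pretrained=False, image_size=384)
    dst.load_state_dict(new_state)  # pos_embedding is now (1, (384 // 16) ** 2 + 1, 768)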