| # ---------------------------------------------------------------------------- | |
| # SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329) | |
| # Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM | |
| # Code based on fairseq: https://github.com/facebookresearch/fairseq | |
| # | |
| # Copyright (c) 2022 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # ---------------------------------------------------------------------------- | |
| """ | |
All the required modules and functions are merged into this single Python file
so that the pre-trained model can easily be used to extract features.
| """ | |
| import math | |
| import numpy as np | |
| import logging | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn import Parameter | |
| from torch import Tensor | |
| from typing import Any, Dict, List, Tuple, Callable, Optional | |
| logger = logging.getLogger(__name__) | |
| # rewrite name for backward compatibility in `make_generation_fast_` | |
| def module_name_fordropout(module_name: str) -> str: | |
| if module_name == "TransformerEncoderBase": | |
| return "TransformerEncoder" | |
| else: | |
| return module_name | |
| def utils_make_positions(tensor, padding_idx: int, onnx_trace: bool = False): | |
| """Replace non-padding symbols with their position numbers. | |
| Position numbers begin at padding_idx+1. Padding symbols are ignored. | |
| """ | |
| # The series of casts and type-conversions here are carefully | |
| # balanced to both work with ONNX export and XLA. In particular XLA | |
| # prefers ints, cumsum defaults to output longs, and ONNX doesn't know | |
| # how to handle the dtype kwarg in cumsum. | |
| mask = tensor.ne(padding_idx).int() | |
| return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx | |
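# Example (illustrative sketch): with padding_idx=1, position numbers start at 2 and
# padded slots keep the padding index. The helper below is hypothetical and only
# demonstrates the expected output.
def _make_positions_example():
    tokens = torch.tensor([[7, 8, 9, 1], [1, 7, 8, 9]])  # 1 marks padding
    return utils_make_positions(tokens, padding_idx=1)
    # -> tensor([[2, 3, 4, 1],
    #            [1, 2, 3, 4]])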
| def utils_item(tensor): | |
| # tpu-comment: making this a no-op for xla devices. | |
| if torch.is_tensor(tensor) and tensor.device.type == "xla": | |
| return tensor.detach() | |
| if hasattr(tensor, "item"): | |
| return tensor.item() | |
| if hasattr(tensor, "__getitem__"): | |
| return tensor[0] | |
| return tensor | |
| def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs): | |
| """ | |
| Helper to wrap layers/modules in FSDP. This falls back to a no-op if | |
| fairscale is not available. | |
| Args: | |
| module (nn.Module): module to (maybe) wrap | |
| min_num_params (int, Optional): minimum number of layer params to wrap | |
| """ | |
| try: | |
| from fairscale.nn import wrap | |
| if min_num_params is not None: | |
| num_params = sum(p.numel() for p in module.parameters()) | |
| if num_params >= min_num_params: | |
| return wrap(module, **kwargs) | |
| else: | |
| return module | |
| else: | |
| return wrap(module, **kwargs) | |
| except ImportError: | |
| return module | |
| def quant_noise(module, p, block_size): | |
| """ | |
| Wraps modules and applies quantization noise to the weights for | |
| subsequent quantization with Iterative Product Quantization as | |
| described in "Training with Quantization Noise for Extreme Model Compression" | |
| Args: | |
| - module: nn.Module | |
| - p: amount of Quantization Noise | |
| - block_size: size of the blocks for subsequent quantization with iPQ | |
| Remarks: | |
| - Module weights must have the right sizes wrt the block size | |
| - Only Linear, Embedding and Conv2d modules are supported for the moment | |
| - For more detail on how to quantize by blocks with convolutional weights, | |
| see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" | |
| - We implement the simplest form of noise here as stated in the paper | |
| which consists in randomly dropping blocks | |
| """ | |
| # if no quantization noise, don't register hook | |
| if p <= 0: | |
| return module | |
| # supported modules | |
| assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) | |
| # test whether module.weight has the right sizes wrt block_size | |
| is_conv = module.weight.ndim == 4 | |
| # 2D matrix | |
| if not is_conv: | |
| assert ( | |
| module.weight.size(1) % block_size == 0 | |
| ), "Input features must be a multiple of block sizes" | |
| # 4D matrix | |
| else: | |
| # 1x1 convolutions | |
| if module.kernel_size == (1, 1): | |
| assert ( | |
| module.in_channels % block_size == 0 | |
| ), "Input channels must be a multiple of block sizes" | |
| # regular convolutions | |
| else: | |
| k = module.kernel_size[0] * module.kernel_size[1] | |
| assert k % block_size == 0, "Kernel size must be a multiple of block size" | |
| def _forward_pre_hook(mod, input): | |
| # no noise for evaluation | |
| if mod.training: | |
| if not is_conv: | |
| # gather weight and sizes | |
| weight = mod.weight | |
| in_features = weight.size(1) | |
| out_features = weight.size(0) | |
| # split weight matrix into blocks and randomly drop selected blocks | |
| mask = torch.zeros( | |
| in_features // block_size * out_features, device=weight.device | |
| ) | |
| mask.bernoulli_(p) | |
| mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) | |
| else: | |
| # gather weight and sizes | |
| weight = mod.weight | |
| in_channels = mod.in_channels | |
| out_channels = mod.out_channels | |
| # split weight matrix into blocks and randomly drop selected blocks | |
| if mod.kernel_size == (1, 1): | |
| mask = torch.zeros( | |
| int(in_channels // block_size * out_channels), | |
| device=weight.device, | |
| ) | |
| mask.bernoulli_(p) | |
| mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) | |
| else: | |
| mask = torch.zeros( | |
| weight.size(0), weight.size(1), device=weight.device | |
| ) | |
| mask.bernoulli_(p) | |
| mask = ( | |
| mask.unsqueeze(2) | |
| .unsqueeze(3) | |
| .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) | |
| ) | |
| # scale weights and apply mask | |
| mask = mask.to( | |
| torch.bool | |
| ) # x.bool() is not currently supported in TorchScript | |
| s = 1 / (1 - p) | |
| mod.weight.data = s * weight.masked_fill(mask, 0) | |
| module.register_forward_pre_hook(_forward_pre_hook) | |
| return module | |
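# Example usage (illustrative sketch, not from the original training recipe): apply 10%
# quantization noise to a linear layer whose 64 input features are divisible by the
# block size of 8; blocks are only dropped while the module is in training mode.
def _quant_noise_example():
    layer = quant_noise(nn.Linear(64, 32), p=0.1, block_size=8)
    layer.train()
    out = layer(torch.randn(4, 64))  # forward pre-hook zeroes random weight blocks and rescales
    return out.shape                 # torch.Size([4, 32])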
| def relu_squared(x: torch.Tensor): | |
| return F.relu(x).pow(2) | |
| def gelu(x: torch.Tensor) -> torch.Tensor: | |
| return torch.nn.functional.gelu(x.float()).type_as(x) | |
| def gelu_accurate(x): | |
| if not hasattr(gelu_accurate, "_a"): | |
| gelu_accurate._a = math.sqrt(2 / math.pi) | |
| return ( | |
| 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) | |
| ) | |
| def get_activation_fn(activation: str) -> Callable: | |
| """Returns the activation function corresponding to `activation`""" | |
| if activation == "relu": | |
| return F.relu | |
| elif activation == "relu_squared": | |
| return relu_squared | |
| elif activation == "gelu": | |
| return gelu | |
| elif activation == "gelu_fast": | |
logger.warning(
| "--activation-fn=gelu_fast has been renamed to gelu_accurate" | |
| ) | |
| return gelu_accurate | |
| elif activation == "gelu_accurate": | |
| return gelu_accurate | |
| elif activation == "tanh": | |
| return torch.tanh | |
| elif activation == "linear": | |
| return lambda x: x | |
| elif activation == "swish": | |
return torch.nn.SiLU()
| else: | |
| raise RuntimeError("--activation-fn {} not supported".format(activation)) | |
| def softmax(x, dim: int, onnx_trace: bool = False): | |
| if onnx_trace: | |
| return F.softmax(x.float(), dim=dim) | |
| else: | |
| return F.softmax(x, dim=dim, dtype=torch.float32) | |
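# Example (illustrative): the activation lookup returns plain tensor-to-tensor callables,
# and this softmax helper always accumulates in float32 for numerical stability.
def _activation_softmax_example():
    act = get_activation_fn("gelu")
    x = torch.randn(2, 5)
    y = act(x)                     # same shape and dtype as the input
    p = softmax(x, dim=-1)         # float32 probabilities
    return y.shape, p.sum(dim=-1)  # each row of p sums to 1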
| def compute_mask_indices( | |
| shape: Tuple[int, int], | |
| padding_mask: Optional[torch.Tensor], | |
| mask_prob: float, | |
| mask_length: int, | |
| mask_type: str = "static", | |
| mask_other: float = 0.0, | |
| min_masks: int = 0, | |
| no_overlap: bool = False, | |
| min_space: int = 0, | |
| require_same_masks: bool = True, | |
| mask_dropout: float = 0.0, | |
| ) -> np.ndarray: | |
| """ | |
| Computes random mask spans for a given shape | |
| Args: | |
shape: the shape for which to compute masks. Should be of size 2,
where the first element is the batch size and the second is the number of timesteps
| padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements | |
mask_prob: probability for each token to be chosen as the start of a span to be masked. This will be multiplied by
the number of timesteps divided by the length of the mask span, to mask approximately this percentage of all elements.
However, due to overlaps, the actual number will be smaller (unless no_overlap is True)
| mask_type: how to compute mask lengths | |
| static = fixed size | |
| uniform = sample from uniform distribution [mask_other, mask_length*2] | |
| normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element | |
poisson = sample from Poisson distribution with lambda = mask_length
| min_masks: minimum number of masked spans | |
no_overlap: if true, will use an alternative recursive algorithm that prevents spans from overlapping
min_space: only used if no_overlap is True; this is how many elements to keep unmasked between spans
require_same_masks: if true, will randomly drop out masks until the same number of masks remains in each sample
mask_dropout: randomly drop this percentage of masks in each example
| """ | |
| bsz, all_sz = shape | |
| mask = np.full((bsz, all_sz), False) | |
| all_num_mask = int( | |
| # add a random number for probabilistic rounding | |
| mask_prob * all_sz / float(mask_length) | |
| + np.random.rand() | |
| ) | |
| all_num_mask = max(min_masks, all_num_mask) | |
| mask_idcs = [] | |
| for i in range(bsz): | |
| if padding_mask is not None: | |
| sz = all_sz - padding_mask[i].long().sum().item() | |
| num_mask = int( | |
| # add a random number for probabilistic rounding | |
| mask_prob * sz / float(mask_length) | |
| + np.random.rand() | |
| ) | |
| num_mask = max(min_masks, num_mask) | |
| else: | |
| sz = all_sz | |
| num_mask = all_num_mask | |
| if mask_type == "static": | |
| lengths = np.full(num_mask, mask_length) | |
| elif mask_type == "uniform": | |
| lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) | |
| elif mask_type == "normal": | |
| lengths = np.random.normal(mask_length, mask_other, size=num_mask) | |
| lengths = [max(1, int(round(x))) for x in lengths] | |
| elif mask_type == "poisson": | |
| lengths = np.random.poisson(mask_length, size=num_mask) | |
| lengths = [int(round(x)) for x in lengths] | |
| else: | |
| raise Exception("unknown mask selection " + mask_type) | |
| if sum(lengths) == 0: | |
| lengths[0] = min(mask_length, sz - 1) | |
| if no_overlap: | |
| mask_idc = [] | |
| def arrange(s, e, length, keep_length): | |
| span_start = np.random.randint(s, e - length) | |
| mask_idc.extend(span_start + i for i in range(length)) | |
| new_parts = [] | |
| if span_start - s - min_space >= keep_length: | |
| new_parts.append((s, span_start - min_space + 1)) | |
| if e - span_start - keep_length - min_space > keep_length: | |
| new_parts.append((span_start + length + min_space, e)) | |
| return new_parts | |
| parts = [(0, sz)] | |
| min_length = min(lengths) | |
| for length in sorted(lengths, reverse=True): | |
| lens = np.fromiter( | |
| (e - s if e - s >= length + min_space else 0 for s, e in parts), | |
int,  # the deprecated np.int alias was removed in recent NumPy versions
| ) | |
| l_sum = np.sum(lens) | |
| if l_sum == 0: | |
| break | |
| probs = lens / np.sum(lens) | |
| c = np.random.choice(len(parts), p=probs) | |
| s, e = parts.pop(c) | |
| parts.extend(arrange(s, e, length, min_length)) | |
| mask_idc = np.asarray(mask_idc) | |
| else: | |
| min_len = min(lengths) | |
| if sz - min_len <= num_mask: | |
| min_len = sz - num_mask - 1 | |
| mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) | |
| mask_idc = np.asarray( | |
| [ | |
| mask_idc[j] + offset | |
| for j in range(len(mask_idc)) | |
| for offset in range(lengths[j]) | |
| ] | |
| ) | |
| mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) | |
| min_len = min([len(m) for m in mask_idcs]) | |
| for i, mask_idc in enumerate(mask_idcs): | |
| if len(mask_idc) > min_len and require_same_masks: | |
| mask_idc = np.random.choice(mask_idc, min_len, replace=False) | |
| if mask_dropout > 0: | |
| num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) | |
| mask_idc = np.random.choice( | |
| mask_idc, len(mask_idc) - num_holes, replace=False | |
| ) | |
| mask[i, mask_idc] = True | |
| return mask | |
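# Example usage (illustrative sketch): sample mask spans of length 10 with an 8% start
# probability over a batch of 2 utterances of 500 frames, wav2vec 2.0 / HuBERT style.
# The numbers are placeholders, not the SpeechLM training configuration.
def _compute_mask_indices_example():
    mask = compute_mask_indices(
        shape=(2, 500),
        padding_mask=None,
        mask_prob=0.08,
        mask_length=10,
        mask_type="static",
    )
    return mask.shape, mask.sum(axis=-1)  # boolean (2, 500) array; equal mask count per row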
| def init_bert_params(module): | |
| """ | |
| Initialize the weights specific to the BERT Model. | |
| This overrides the default initializations depending on the specified arguments. | |
1. If normal_init_linear_weights is set then the weights of the linear
layer will be initialized using the normal distribution and the
bias will be set to the specified value.
2. If normal_init_embed_weights is set then the weights of the embedding
layer will be initialized using the normal distribution.
3. If normal_init_proj_weights is set then the weights of
in_project_weight for MultiheadAttention will be initialized using
the normal distribution (to be validated).
| """ | |
| def normal_(data): | |
| # with FSDP, module params will be on CUDA, so we cast them back to CPU | |
| # so that the RNG is consistent with and without FSDP | |
| data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) | |
| if isinstance(module, nn.Linear): | |
| normal_(module.weight.data) | |
| if module.bias is not None: | |
| module.bias.data.zero_() | |
| if isinstance(module, nn.Embedding): | |
| normal_(module.weight.data) | |
| if module.padding_idx is not None: | |
| module.weight.data[module.padding_idx].zero_() | |
| if isinstance(module, MultiheadAttention): | |
| normal_(module.q_proj.weight.data) | |
| normal_(module.k_proj.weight.data) | |
| normal_(module.v_proj.weight.data) | |
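# Example (illustrative): apply the BERT-style initialization to a small module tree;
# Linear and Embedding weights are redrawn from N(0, 0.02) and biases are zeroed.
def _init_bert_params_example():
    model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
    model.apply(init_bert_params)
    return model[0].weight.std()  # roughly 0.02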
| def pad_to_multiple(x, multiple, dim=-1, value=0): | |
| # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41 | |
| if x is None: | |
| return None, 0 | |
| tsz = x.size(dim) | |
| m = tsz / multiple | |
| remainder = math.ceil(m) * multiple - tsz | |
| if m.is_integer(): | |
| return x, 0 | |
| pad_offset = (0,) * (-1 - dim) * 2 | |
| return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder | |
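# Example (illustrative): pad the time dimension of a (batch, time, channels) tensor
# up to the next multiple of 8.
def _pad_to_multiple_example():
    x = torch.randn(2, 13, 4)
    padded, pad_length = pad_to_multiple(x, 8, dim=-2, value=0)
    return padded.shape, pad_length  # torch.Size([2, 16, 4]), 3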
| def is_xla_tensor(tensor): | |
| return torch.is_tensor(tensor) and tensor.device.type == "xla" | |
| def index_put(tensor, indices, value): | |
| if is_xla_tensor(tensor): | |
| for _ in range(indices.dim(), tensor.dim()): | |
| indices = indices.unsqueeze(-1) | |
| if indices.size(-1) < tensor.size(-1): | |
| indices = indices.expand_as(tensor) | |
| tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) | |
| else: | |
| tensor[indices] = value | |
| return tensor | |
| def PositionalEmbedding( | |
| num_embeddings: int, | |
| embedding_dim: int, | |
| padding_idx: int, | |
| learned: bool = False, | |
| ): | |
| if learned: | |
| # if padding_idx is specified then offset the embedding ids by | |
| # this index and adjust num_embeddings appropriately | |
| # TODO: The right place for this offset would be inside | |
| # LearnedPositionalEmbedding. Move this there for a cleaner implementation. | |
| if padding_idx is not None: | |
| num_embeddings = num_embeddings + padding_idx + 1 | |
| m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx) | |
| nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5) | |
| if padding_idx is not None: | |
| nn.init.constant_(m.weight[padding_idx], 0) | |
| else: | |
| m = SinusoidalPositionalEmbedding( | |
| embedding_dim, | |
| padding_idx, | |
| init_size=num_embeddings + padding_idx + 1, | |
| ) | |
| return m | |
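# Example (illustrative sketch): build learned vs. sinusoidal positional embeddings for
# up to 512 positions with padding_idx=1; both classes are defined further below in this file.
def _positional_embedding_example():
    learned = PositionalEmbedding(512, 768, padding_idx=1, learned=True)
    sinusoidal = PositionalEmbedding(512, 768, padding_idx=1, learned=False)
    tokens = torch.tensor([[5, 6, 7, 1, 1]])                 # 1 marks padding
    return learned(tokens).shape, sinusoidal(tokens).shape   # both (1, 5, 768)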
| def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): | |
| if torch.jit.is_scripting() or torch.jit.is_tracing(): | |
| export = True | |
| if not export and torch.cuda.is_available() and has_fused_layernorm: | |
| return FusedLayerNorm(normalized_shape, eps, elementwise_affine) | |
| return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) | |
| class TransformerEncoderBase(nn.Module): | |
| """ | |
| Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer | |
| is a :class:`TransformerEncoderLayer`. | |
| Args: | |
cfg: configuration object holding the encoder, dropout and quant-noise settings
dictionary: deprecated; pass ``None``
| embed_tokens (torch.nn.Embedding): input embedding | |
| """ | |
| def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0): | |
| self.cfg = cfg | |
| super().__init__() | |
| self.register_buffer("version", torch.Tensor([3])) | |
| self.dropout_module = FairseqDropout( | |
| cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__) | |
| ) | |
| self.encoder_layerdrop = cfg.encoder.layerdrop | |
| embed_dim = embed_tokens.embedding_dim if embed_tokens is not None else cfg.encoder.embed_dim | |
| self.padding_idx = embed_tokens.padding_idx if embed_tokens is not None else 1 | |
| self.max_source_positions = cfg.max_source_positions | |
| self.embed_tokens = embed_tokens | |
| self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) | |
| self.embed_positions = ( | |
| PositionalEmbedding( | |
| cfg.max_source_positions, | |
| embed_dim, | |
| self.padding_idx, | |
| learned=cfg.encoder.learned_pos, | |
| ) | |
| if not cfg.no_token_positional_embeddings | |
| else None | |
| ) | |
| if cfg.layernorm_embedding: | |
| self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export) | |
| else: | |
| self.layernorm_embedding = None | |
| if not cfg.adaptive_input and cfg.quant_noise.pq > 0: | |
| self.quant_noise = quant_noise( | |
| nn.Linear(embed_dim, embed_dim, bias=False), | |
| cfg.quant_noise.pq, | |
| cfg.quant_noise.pq_block_size, | |
| ) | |
| else: | |
| self.quant_noise = None | |
| if self.encoder_layerdrop > 0.0: | |
| self.layers = LayerDropModuleList(p=self.encoder_layerdrop) | |
| else: | |
| self.layers = nn.ModuleList([]) | |
| self.use_rel_pos_enc = use_rel_pos_enc | |
| self.scaling_for_att = scaling_for_att | |
| self.layers.extend( | |
| [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)] | |
| ) | |
| self.num_layers = len(self.layers) | |
| if cfg.encoder.normalize_before: | |
| self.layer_norm = LayerNorm(embed_dim, export=cfg.export) | |
| else: | |
| self.layer_norm = None | |
| if self.use_rel_pos_enc: | |
| self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160) | |
| def build_encoder_layer(self, cfg): | |
| layer = TransformerEncoderLayerBase(cfg, has_relative_attention_bias=self.use_rel_pos_enc, scaling_for_att=self.scaling_for_att) | |
| checkpoint = cfg.checkpoint_activations | |
| if checkpoint: | |
| raise ValueError("We don't support checkpoint_activations for now! Please set cfg.checkpoint_activations=False.") | |
| min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 | |
| layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) | |
| return layer | |
| def forward_embedding( | |
| self, src_tokens, token_embedding: Optional[torch.Tensor] = None | |
| ): | |
| # embed tokens and positions | |
| if token_embedding is None: | |
| token_embedding = self.embed_tokens(src_tokens) | |
| x = embed = self.embed_scale * token_embedding | |
| if self.embed_positions is not None: | |
| x = embed + self.embed_positions(src_tokens) | |
| if self.layernorm_embedding is not None: | |
| x = self.layernorm_embedding(x) | |
| x = self.dropout_module(x) | |
| if self.quant_noise is not None: | |
| x = self.quant_noise(x) | |
| return x, embed | |
| def forward( | |
| self, | |
| src_tokens, | |
| src_lengths: Optional[torch.Tensor] = None, | |
| return_all_hiddens: bool = False, | |
| token_embeddings: Optional[torch.Tensor] = None, | |
| uniformity_layers: Optional[List[int]] = None, | |
| ): | |
| """ | |
| Args: | |
| src_tokens (LongTensor): tokens in the source language of shape | |
| `(batch, src_len)` | |
| src_lengths (torch.LongTensor): lengths of each source sentence of | |
| shape `(batch)` | |
| return_all_hiddens (bool, optional): also return all of the | |
| intermediate hidden states (default: False). | |
| token_embeddings (torch.Tensor, optional): precomputed embeddings | |
| default `None` will recompute embeddings | |
| Returns: | |
| dict: | |
| - **encoder_out** (Tensor): the last encoder layer's output of | |
| shape `(src_len, batch, embed_dim)` | |
| - **encoder_padding_mask** (ByteTensor): the positions of | |
| padding elements of shape `(batch, src_len)` | |
| - **encoder_embedding** (Tensor): the (scaled) embedding lookup | |
| of shape `(batch, src_len, embed_dim)` | |
| - **encoder_states** (List[Tensor]): all intermediate | |
| hidden states of shape `(src_len, batch, embed_dim)`. | |
| Only populated if *return_all_hiddens* is True. | |
| """ | |
| return self.forward_scriptable( | |
| src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers | |
| ) | |
| # TorchScript doesn't support super() method so that the scriptable Subclass | |
| # can't access the base class model in Torchscript. | |
| # Current workaround is to add a helper function with different name and | |
| # call the helper function from scriptable Subclass. | |
| def forward_scriptable( | |
| self, | |
| src_tokens, | |
| src_lengths: Optional[torch.Tensor] = None, | |
| return_all_hiddens: bool = False, | |
| token_embeddings: Optional[torch.Tensor] = None, | |
| uniformity_layers: Optional[List[int]] = None, | |
| ): | |
| """ | |
| Args: | |
| src_tokens (LongTensor): tokens in the source language of shape | |
| `(batch, src_len)` | |
| src_lengths (torch.LongTensor): lengths of each source sentence of | |
| shape `(batch)` | |
| return_all_hiddens (bool, optional): also return all of the | |
| intermediate hidden states (default: False). | |
| token_embeddings (torch.Tensor, optional): precomputed embeddings | |
| default `None` will recompute embeddings | |
| Returns: | |
| dict: | |
| - **encoder_out** (Tensor): the last encoder layer's output of | |
| shape `(src_len, batch, embed_dim)` | |
| - **encoder_padding_mask** (ByteTensor): the positions of | |
| padding elements of shape `(batch, src_len)` | |
| - **encoder_embedding** (Tensor): the (scaled) embedding lookup | |
| of shape `(batch, src_len, embed_dim)` | |
| - **encoder_states** (List[Tensor]): all intermediate | |
| hidden states of shape `(src_len, batch, embed_dim)`. | |
| Only populated if *return_all_hiddens* is True. | |
| """ | |
| # compute padding mask | |
| encoder_padding_mask = src_tokens.eq(self.padding_idx) | |
| has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() | |
| x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) | |
| # account for padding while computing the representation | |
| if has_pads: | |
| x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) | |
| # B x T x C -> T x B x C | |
| x = x.transpose(0, 1) | |
| if self.use_rel_pos_enc: | |
| x_len = x.shape[0] | |
| pos_seq = torch.arange(0, x_len).long().to(x.device) | |
| pos_seq = pos_seq[:, None] - pos_seq[None, :] | |
| pos_k, pos_v = self.pos_emb(pos_seq) | |
| else: | |
| pos_k = None | |
| encoder_states = [] | |
| uniformity_hiddens = [] | |
| if return_all_hiddens: | |
| encoder_states.append(x) | |
| if uniformity_layers is not None and 0 in uniformity_layers: | |
| x = F.normalize(x.float(), dim=-1).type_as(x) | |
| uniformity_hiddens.append(x) | |
| # encoder layers | |
| for i, layer in enumerate(self.layers): | |
| x = layer( | |
| x, encoder_padding_mask=encoder_padding_mask if has_pads else None, | |
| pos_bias=pos_k, | |
| ) | |
| if uniformity_layers is not None and i+1 in uniformity_layers: | |
| x = F.normalize(x.float(), dim=-1).type_as(x) | |
| uniformity_hiddens.append(x) | |
| if return_all_hiddens: | |
| assert encoder_states is not None | |
| encoder_states.append(x) | |
| if self.layer_norm is not None: | |
| x = self.layer_norm(x) | |
# The PyTorch Mobile lite interpreter does not support returning NamedTuple in
| # `forward` so we use a dictionary instead. | |
| # TorchScript does not support mixed values so the values are all lists. | |
| # The empty list is equivalent to None. | |
| src_lengths = ( | |
| src_tokens.ne(self.padding_idx) | |
| .sum(dim=1, dtype=torch.int32) | |
| .reshape(-1, 1) | |
| .contiguous() | |
| ) | |
| return { | |
| "encoder_out": [x], # T x B x C | |
| "encoder_padding_mask": [encoder_padding_mask], # B x T | |
| "encoder_embedding": [encoder_embedding], # B x T x C | |
| "encoder_states": encoder_states, # List[T x B x C] | |
| "uniformity_hiddens": uniformity_hiddens, # List[T x B x C] | |
| "src_tokens": [], | |
| "src_lengths": [src_lengths], | |
| } | |
| def forward_torchscript(self, net_input: Dict[str, Tensor]): | |
| """A TorchScript-compatible version of forward. | |
| Encoders which use additional arguments may want to override | |
| this method for TorchScript compatibility. | |
| """ | |
| if torch.jit.is_scripting(): | |
| return self.forward( | |
| src_tokens=net_input["src_tokens"], | |
| src_lengths=net_input["src_lengths"], | |
| ) | |
| else: | |
| return self.forward_non_torchscript(net_input) | |
| def forward_non_torchscript(self, net_input: Dict[str, Tensor]): | |
| encoder_input = { | |
| k: v for k, v in net_input.items() if k != "prev_output_tokens" | |
| } | |
| return self.forward(**encoder_input) | |
| def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): | |
| """ | |
| Reorder encoder output according to *new_order*. | |
| Args: | |
| encoder_out: output from the ``forward()`` method | |
| new_order (LongTensor): desired order | |
| Returns: | |
| *encoder_out* rearranged according to *new_order* | |
| """ | |
| if len(encoder_out["encoder_out"]) == 0: | |
| new_encoder_out = [] | |
| else: | |
| new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] | |
| if len(encoder_out["encoder_padding_mask"]) == 0: | |
| new_encoder_padding_mask = [] | |
| else: | |
| new_encoder_padding_mask = [ | |
| encoder_out["encoder_padding_mask"][0].index_select(0, new_order) | |
| ] | |
| if len(encoder_out["encoder_embedding"]) == 0: | |
| new_encoder_embedding = [] | |
| else: | |
| new_encoder_embedding = [ | |
| encoder_out["encoder_embedding"][0].index_select(0, new_order) | |
| ] | |
| if len(encoder_out["src_tokens"]) == 0: | |
| src_tokens = [] | |
| else: | |
| src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] | |
| if len(encoder_out["src_lengths"]) == 0: | |
| src_lengths = [] | |
| else: | |
| src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] | |
| encoder_states = encoder_out["encoder_states"] | |
| if len(encoder_states) > 0: | |
| for idx, state in enumerate(encoder_states): | |
| encoder_states[idx] = state.index_select(1, new_order) | |
| return { | |
| "encoder_out": new_encoder_out, # T x B x C | |
| "encoder_padding_mask": new_encoder_padding_mask, # B x T | |
| "encoder_embedding": new_encoder_embedding, # B x T x C | |
| "encoder_states": encoder_states, # List[T x B x C] | |
| "src_tokens": src_tokens, # B x T | |
| "src_lengths": src_lengths, # B x 1 | |
| } | |
| def max_positions(self): | |
| """Maximum input length supported by the encoder.""" | |
| if self.embed_positions is None: | |
| return self.max_source_positions | |
| return min(self.max_source_positions, self.embed_positions.max_positions) | |
| def upgrade_state_dict_named(self, state_dict, name): | |
| """Upgrade a (possibly old) state dict for new versions.""" | |
| if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): | |
| weights_key = "{}.embed_positions.weights".format(name) | |
| if weights_key in state_dict: | |
| print("deleting {0}".format(weights_key)) | |
| del state_dict[weights_key] | |
| state_dict[ | |
| "{}.embed_positions._float_tensor".format(name) | |
| ] = torch.FloatTensor(1) | |
| for i in range(self.num_layers): | |
| # update layer norms | |
| self.layers[i].upgrade_state_dict_named( | |
| state_dict, "{}.layers.{}".format(name, i) | |
| ) | |
| version_key = "{}.version".format(name) | |
| if utils_item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: | |
| # earlier checkpoints did not normalize after the stack of layers | |
| self.layer_norm = None | |
| self.normalize = False | |
| state_dict[version_key] = torch.Tensor([1]) | |
| return state_dict | |
| def set_num_updates(self, num_updates): | |
| """State from trainer to pass along to model at every update.""" | |
| def _apply(m): | |
| if hasattr(m, "set_num_updates") and m != self: | |
| m.set_num_updates(num_updates) | |
| self.apply(_apply) | |
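# Example usage (illustrative sketch): the nested namespace below only mirrors the cfg
# attributes read by TransformerEncoderBase (cfg.encoder.*, cfg.quant_noise.*, ...);
# the hyper-parameter values are made-up placeholders, not a SpeechLM configuration.
def _transformer_encoder_base_example():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        dropout=0.1, max_source_positions=1024, no_scale_embedding=False,
        no_token_positional_embeddings=False, layernorm_embedding=False,
        adaptive_input=False, export=False, checkpoint_activations=False,
        min_params_to_wrap=int(1e8), activation_fn="gelu",
        activation_dropout=0.0, relu_dropout=0.0, attention_dropout=0.1,
        encoder=SimpleNamespace(embed_dim=256, ffn_embed_dim=1024, layers=2,
                                attention_heads=4, layerdrop=0.0,
                                normalize_before=True, learned_pos=False),
        quant_noise=SimpleNamespace(pq=0.0, pq_block_size=8),
    )
    embed_tokens = nn.Embedding(1000, 256, padding_idx=1)
    encoder = TransformerEncoderBase(cfg, dictionary=None, embed_tokens=embed_tokens)
    src_tokens = torch.randint(2, 1000, (3, 20))   # toy batch with no padding tokens
    out = encoder(src_tokens)
    return out["encoder_out"][0].shape             # (seq_len, batch, embed_dim) = (20, 3, 256)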
| class TransformerEncoderLayerBase(nn.Module): | |
| """Encoder layer block. | |
| In the original paper each operation (multi-head attention or FFN) is | |
| postprocessed with: `dropout -> add residual -> layernorm`. In the | |
| tensor2tensor code they suggest that learning is more robust when | |
| preprocessing each layer with layernorm and postprocessing with: | |
| `dropout -> add residual`. We default to the approach in the paper, but the | |
| tensor2tensor approach can be enabled by setting | |
| *cfg.encoder.normalize_before* to ``True``. | |
| Args: | |
cfg: configuration object holding the encoder settings (see TransformerEncoderBase)
| """ | |
| def __init__(self, cfg, has_relative_attention_bias=False, scaling_for_att=1.0): | |
| super().__init__() | |
| self.cfg = cfg | |
| self.embed_dim = cfg.encoder.embed_dim | |
| self.quant_noise = cfg.quant_noise.pq | |
| self.quant_noise_block_size = cfg.quant_noise.pq_block_size | |
| self.self_attn = self.build_self_attention(self.embed_dim, cfg, has_relative_attention_bias=has_relative_attention_bias, scaling_for_att=scaling_for_att) | |
| self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) | |
| self.dropout_module = FairseqDropout( | |
| cfg.dropout, module_name=self.__class__.__name__ | |
| ) | |
| self.activation_fn = get_activation_fn(activation=cfg.activation_fn) | |
| activation_dropout_p = cfg.activation_dropout | |
| if activation_dropout_p == 0: | |
| # for backwards compatibility with models that use cfg.relu_dropout | |
| activation_dropout_p = cfg.relu_dropout or 0 | |
| self.activation_dropout_module = FairseqDropout( | |
| float(activation_dropout_p), module_name=self.__class__.__name__ | |
| ) | |
| self.normalize_before = cfg.encoder.normalize_before | |
| self.fc1 = self.build_fc1( | |
| self.embed_dim, | |
| cfg.encoder.ffn_embed_dim, | |
| self.quant_noise, | |
| self.quant_noise_block_size, | |
| ) | |
| self.fc2 = self.build_fc2( | |
| cfg.encoder.ffn_embed_dim, | |
| self.embed_dim, | |
| self.quant_noise, | |
| self.quant_noise_block_size, | |
| ) | |
| self.final_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) | |
| if has_relative_attention_bias: | |
| self.norm_k = LayerNorm(self.embed_dim // cfg.encoder.attention_heads) | |
| def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size): | |
| return quant_noise( | |
| nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size | |
| ) | |
| def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size): | |
| return quant_noise( | |
| nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size | |
| ) | |
| def build_self_attention(self, embed_dim, cfg, has_relative_attention_bias=False, scaling_for_att=1.0): | |
| return MultiheadAttention( | |
| embed_dim, | |
| cfg.encoder.attention_heads, | |
| dropout=cfg.attention_dropout, | |
| self_attention=True, | |
| q_noise=self.quant_noise, | |
| qn_block_size=self.quant_noise_block_size, | |
| has_relative_attention_bias=has_relative_attention_bias, | |
| scaling_for_att=scaling_for_att, | |
| ) | |
| def residual_connection(self, x, residual): | |
| return residual + x | |
| def upgrade_state_dict_named(self, state_dict, name): | |
| """ | |
| Rename layer norm states from `...layer_norms.0.weight` to | |
| `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to | |
| `...final_layer_norm.weight` | |
| """ | |
| layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"} | |
| for old, new in layer_norm_map.items(): | |
| for m in ("weight", "bias"): | |
| k = "{}.layer_norms.{}.{}".format(name, old, m) | |
| if k in state_dict: | |
| state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k] | |
| del state_dict[k] | |
| def forward( | |
| self, | |
| x, | |
| encoder_padding_mask: Optional[Tensor], | |
| attn_mask: Optional[Tensor] = None, | |
| pos_bias=None, | |
| ): | |
| """ | |
| Args: | |
| x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` | |
| encoder_padding_mask (ByteTensor): binary ByteTensor of shape | |
| `(batch, seq_len)` where padding elements are indicated by ``1``. | |
| attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, | |
| where `tgt_len` is the length of output and `src_len` is the | |
| length of input, though here both are equal to `seq_len`. | |
| `attn_mask[tgt_i, src_j] = 1` means that when calculating the | |
| embedding for `tgt_i`, we exclude (mask out) `src_j`. This is | |
| useful for strided self-attention. | |
| Returns: | |
| encoded output of shape `(seq_len, batch, embed_dim)` | |
| """ | |
| # anything in original attn_mask = 1, becomes -1e8 | |
| # anything in original attn_mask = 0, becomes 0 | |
| # Note that we cannot use -inf here, because at some edge cases, | |
| # the attention weight (before softmax) for some padded element in query | |
| # will become -inf, which results in NaN in model parameters | |
| if attn_mask is not None: | |
| attn_mask = attn_mask.masked_fill( | |
| attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 | |
| ) | |
| residual = x | |
| if self.normalize_before: | |
| x = self.self_attn_layer_norm(x) | |
| if pos_bias is not None: | |
| pos_bias = self.norm_k(pos_bias) | |
| x, _ = self.self_attn( | |
| query=x, | |
| key=x, | |
| value=x, | |
| key_padding_mask=encoder_padding_mask, | |
| need_weights=False, | |
| attn_mask=attn_mask, | |
| position_bias=pos_bias, | |
| ) | |
| x = self.dropout_module(x) | |
| x = self.residual_connection(x, residual) | |
| if not self.normalize_before: | |
| x = self.self_attn_layer_norm(x) | |
| residual = x | |
| if self.normalize_before: | |
| x = self.final_layer_norm(x) | |
| x = self.activation_fn(self.fc1(x)) | |
| x = self.activation_dropout_module(x) | |
| x = self.fc2(x) | |
| x = self.dropout_module(x) | |
| x = self.residual_connection(x, residual) | |
| if not self.normalize_before: | |
| x = self.final_layer_norm(x) | |
| return x | |
| class TransformerEncoder(nn.Module): | |
| """ | |
| wav2vec-style transformer encoder. | |
| """ | |
| def __init__(self, args): | |
| super().__init__() | |
self.args = args  # kept so that max_positions() below can read args.max_positions
self.dropout = args.dropout
| self.embedding_dim = args.encoder_embed_dim | |
| self.required_seq_len_multiple = args.required_seq_len_multiple | |
| self.pos_conv = nn.Conv1d( | |
| self.embedding_dim, | |
| self.embedding_dim, | |
| kernel_size=args.conv_pos, | |
| padding=args.conv_pos // 2, | |
| groups=args.conv_pos_groups, | |
| ) | |
| dropout = 0 | |
| std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) | |
| nn.init.normal_(self.pos_conv.weight, mean=0, std=std) | |
| nn.init.constant_(self.pos_conv.bias, 0) | |
| self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) | |
| self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) | |
| layers = [] | |
| self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False) | |
| for _ in range(args.encoder_layers): | |
| layer = TransformerSentenceEncoderLayer( | |
| embedding_dim=self.embedding_dim, | |
| ffn_embedding_dim=args.encoder_ffn_embed_dim, | |
| num_attention_heads=args.encoder_attention_heads, | |
| dropout=self.dropout, | |
| attention_dropout=args.attention_dropout, | |
| activation_dropout=args.activation_dropout, | |
| activation_fn=args.activation_fn, | |
| layer_norm_first=args.layer_norm_first, | |
| has_relative_attention_bias=self.use_rel_pos_enc, | |
| scaling_for_att=getattr(args, "scaling_for_att", 1.0) | |
| ) | |
| if args.checkpoint_activations: | |
| raise ValueError("We don't support checkpoint_activations for now! Please set checkpoint_activations=False.") | |
| layers.append(layer) | |
| self.layers = nn.ModuleList(layers) | |
| self.layer_norm_first = args.layer_norm_first | |
| self.layer_norm = LayerNorm(self.embedding_dim) | |
| self.layerdrop = args.encoder_layerdrop | |
| if self.use_rel_pos_enc: | |
| self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, 160) | |
| self.apply(init_bert_params) | |
| def forward(self, x, padding_mask=None, layer=None, conv_pos=True): | |
| x, layer_results = self.extract_features(x, padding_mask, layer, conv_pos) | |
| if self.layer_norm_first and (layer is None or layer >= len(self.layers) - 1): | |
| x = self.layer_norm(x) | |
| return x, layer_results | |
| def extract_features(self, x, padding_mask=None, tgt_layer=None, conv_pos=True): | |
| if padding_mask is not None: | |
| x = index_put(x, padding_mask, 0) | |
| if conv_pos: | |
| x_conv = self.pos_conv(x.transpose(1, 2)) | |
| x_conv = x_conv.transpose(1, 2) | |
| x = x + x_conv | |
| if not self.layer_norm_first: | |
| x = self.layer_norm(x) | |
| # pad to the sequence length dimension | |
| x, pad_length = pad_to_multiple( | |
| x, self.required_seq_len_multiple, dim=-2, value=0 | |
| ) | |
| if pad_length > 0 and padding_mask is None: | |
| padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool) | |
| padding_mask[:, -pad_length:] = True | |
| else: | |
| padding_mask, _ = pad_to_multiple( | |
| padding_mask, self.required_seq_len_multiple, dim=-1, value=True | |
| ) | |
| x = F.dropout(x, p=self.dropout, training=self.training) | |
| # B x T x C -> T x B x C | |
| x = x.transpose(0, 1) | |
| if self.use_rel_pos_enc: | |
| x_len = x.shape[0] | |
| pos_seq = torch.arange(0, x_len).long().to(x.device) | |
| pos_seq = pos_seq[:, None] - pos_seq[None, :] | |
| pos_k, pos_v = self.pos_emb(pos_seq) | |
| else: | |
| pos_k = None | |
| layer_results = [] | |
| r = None | |
| for i, layer in enumerate(self.layers): | |
| dropout_probability = np.random.random() | |
| if not self.training or (dropout_probability > self.layerdrop): | |
| x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_k) | |
| if tgt_layer is not None: | |
| # unpad if needed | |
| if pad_length > 0: | |
| layer_results.append( | |
| x[:-pad_length] | |
| # ( | |
| # x[:-pad_length], | |
| # z[:, :-pad_length, :-pad_length] | |
| # if z is not None | |
| # else z, | |
| # ) | |
| ) | |
| else: | |
| # layer_results.append((x, z)) | |
| layer_results.append(x) | |
| if i == tgt_layer: | |
| r = x | |
| break | |
| if r is not None: | |
| x = r | |
| # T x B x C -> B x T x C | |
| x = x.transpose(0, 1) | |
# undo padding
| if pad_length > 0: | |
| x = x[:, :-pad_length] | |
| return x, layer_results | |
| def max_positions(self): | |
| """Maximum output length supported by the encoder.""" | |
| return self.args.max_positions | |
| def upgrade_state_dict_named(self, state_dict, name): | |
| """Upgrade a (possibly old) state dict for new versions of fairseq.""" | |
| return state_dict | |
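# Example usage (illustrative sketch): run the wav2vec-style encoder on a batch of
# frame-level features. The hyper-parameter values are placeholders, and the
# constructor also relies on SamePad, which is defined elsewhere in this file.
def _wav2vec_transformer_encoder_example():
    from types import SimpleNamespace
    args = SimpleNamespace(
        dropout=0.1, encoder_embed_dim=256, required_seq_len_multiple=2,
        conv_pos=128, conv_pos_groups=16, encoder_layers=2,
        encoder_ffn_embed_dim=1024, encoder_attention_heads=4,
        attention_dropout=0.1, activation_dropout=0.0, activation_fn="gelu",
        layer_norm_first=True, encoder_layerdrop=0.0,
        checkpoint_activations=False, max_positions=8000,
    )
    encoder = TransformerEncoder(args)
    feats = torch.randn(2, 100, 256)                 # (batch, frames, channels)
    x, layer_results = encoder(feats, padding_mask=None)
    return x.shape                                   # torch.Size([2, 100, 256])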
| class TransformerSentenceEncoderLayer(nn.Module): | |
| """ | |
| wav2vec-style transformer layer | |
| """ | |
| def __init__( | |
| self, | |
embedding_dim: int = 768,
ffn_embedding_dim: int = 3072,
num_attention_heads: int = 8,
| dropout: float = 0.1, | |
| attention_dropout: float = 0.1, | |
| activation_dropout: float = 0.1, | |
| activation_fn: str = "relu", | |
| layer_norm_first: bool = False, | |
| has_relative_attention_bias: bool = False, | |
| scaling_for_att: float = 1.0, | |
| ) -> None: | |
| super().__init__() | |
| # Initialize parameters | |
| self.embedding_dim = embedding_dim | |
| self.dropout = dropout | |
| self.activation_dropout = activation_dropout | |
| # Initialize blocks | |
| self.activation_fn = get_activation_fn(activation_fn) | |
| self.self_attn = MultiheadAttention( | |
| self.embedding_dim, | |
| num_attention_heads, | |
| dropout=attention_dropout, | |
| self_attention=True, | |
| has_relative_attention_bias=has_relative_attention_bias, | |
| scaling_for_att=scaling_for_att | |
| ) | |
| self.dropout1 = nn.Dropout(dropout) | |
| self.dropout2 = nn.Dropout(self.activation_dropout) | |
| self.dropout3 = nn.Dropout(dropout) | |
| self.layer_norm_first = layer_norm_first | |
| # layer norm associated with the self attention layer | |
| self.self_attn_layer_norm = LayerNorm(self.embedding_dim) | |
| self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) | |
| self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) | |
| # layer norm associated with the position wise feed-forward NN | |
| self.final_layer_norm = LayerNorm(self.embedding_dim) | |
| if has_relative_attention_bias: | |
| self.norm_k = LayerNorm(self.embedding_dim//num_attention_heads) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
| need_weights: bool = False, | |
| att_args=None, | |
| pos_bias=None, | |
| ): | |
| """ | |
| LayerNorm is applied either before or after the self-attention/ffn | |
modules, similar to the original Transformer implementation.
| """ | |
| residual = x | |
| if self.layer_norm_first: | |
| x = self.self_attn_layer_norm(x) | |
| if pos_bias is not None: | |
| pos_bias = self.norm_k(pos_bias) | |
| x, attn = self.self_attn( | |
| query=x, | |
| key=x, | |
| value=x, | |
| key_padding_mask=self_attn_padding_mask, | |
| attn_mask=self_attn_mask, | |
| position_bias=pos_bias, | |
| ) | |
| x = self.dropout1(x) | |
| x = residual + x | |
| residual = x | |
| x = self.final_layer_norm(x) | |
| x = self.activation_fn(self.fc1(x)) | |
| x = self.dropout2(x) | |
| x = self.fc2(x) | |
| x = self.dropout3(x) | |
| x = residual + x | |
| else: | |
| x, attn = self.self_attn( | |
| query=x, | |
| key=x, | |
| value=x, | |
| key_padding_mask=self_attn_padding_mask, | |
| position_bias=pos_bias, | |
| ) | |
| x = self.dropout1(x) | |
| x = residual + x | |
| x = self.self_attn_layer_norm(x) | |
| residual = x | |
| x = self.activation_fn(self.fc1(x)) | |
| x = self.dropout2(x) | |
| x = self.fc2(x) | |
| x = self.dropout3(x) | |
| x = residual + x | |
| x = self.final_layer_norm(x) | |
| return x, attn | |
| class FairseqDropout(nn.Module): | |
| def __init__(self, p, module_name=None): | |
| super().__init__() | |
| self.p = p | |
| self.module_name = module_name | |
| self.apply_during_inference = False | |
| def forward(self, x, inplace: bool = False): | |
| if self.p > 0 and (self.training or self.apply_during_inference): | |
| return F.dropout(x, p=self.p, training=True, inplace=inplace) | |
| else: | |
| return x | |
| def make_generation_fast_( | |
| self, | |
| name: str, | |
| retain_dropout: bool = False, | |
| retain_dropout_modules: Optional[List[str]] = None, | |
| **kwargs | |
| ): | |
| if retain_dropout: | |
| if retain_dropout_modules is not None and self.module_name is None: | |
| logger.warning( | |
| "Cannot enable dropout during inference for module {} " | |
| "because module_name was not set".format(name) | |
| ) | |
| elif ( | |
| retain_dropout_modules is None # if None, apply to all modules | |
| or self.module_name in retain_dropout_modules | |
| ): | |
| logger.info( | |
| "Enabling dropout during inference for module: {}".format(name) | |
| ) | |
| self.apply_during_inference = True | |
| else: | |
| logger.info("Disabling dropout for module: {}".format(name)) | |
| class LearnedPositionalEmbedding(nn.Embedding): | |
| """ | |
| This module learns positional embeddings up to a fixed maximum size. | |
| Padding ids are ignored by either offsetting based on padding_idx | |
| or by setting padding_idx to None and ensuring that the appropriate | |
| position ids are passed to the forward function. | |
| """ | |
| def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): | |
| super().__init__(num_embeddings, embedding_dim, padding_idx) | |
| self.onnx_trace = False | |
| if self.padding_idx is not None: | |
| self.max_positions = self.num_embeddings - self.padding_idx - 1 | |
| else: | |
| self.max_positions = self.num_embeddings | |
| def forward( | |
| self, | |
| input: Tensor, | |
| incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, | |
| positions: Optional[Tensor] = None, | |
| ): | |
| """Input is expected to be of size [bsz x seqlen].""" | |
| assert (positions is None) or ( | |
| self.padding_idx is None | |
| ), "If positions is pre-computed then padding_idx should not be set." | |
| if positions is None: | |
| if incremental_state is not None: | |
| # positions is the same for every token when decoding a single step | |
| # Without the int() cast, it doesn't work in some cases when exporting to ONNX | |
| positions = torch.zeros( | |
| (1, 1), device=input.device, dtype=input.dtype | |
| ).fill_(int(self.padding_idx + input.size(1))) | |
| else: | |
| positions = utils_make_positions( | |
| input, self.padding_idx, onnx_trace=self.onnx_trace | |
| ) | |
| positions = torch.clamp(positions, max=self.padding_idx + self.max_positions) | |
| return F.embedding( | |
| positions, | |
| self.weight, | |
| self.padding_idx, | |
| self.max_norm, | |
| self.norm_type, | |
| self.scale_grad_by_freq, | |
| self.sparse, | |
| ) | |
| class SinusoidalPositionalEmbedding(nn.Module): | |
| """This module produces sinusoidal positional embeddings of any length. | |
| Padding symbols are ignored. | |
| """ | |
| def __init__(self, embedding_dim, padding_idx, init_size=1024): | |
| super().__init__() | |
| self.embedding_dim = embedding_dim | |
| self.padding_idx = padding_idx if padding_idx is not None else 0 | |
| self.weights = SinusoidalPositionalEmbedding.get_embedding( | |
| init_size, embedding_dim, padding_idx | |
| ) | |
| self.onnx_trace = False | |
| self.register_buffer("_float_tensor", torch.FloatTensor(1)) | |
| self.max_positions = int(1e5) | |
| def prepare_for_onnx_export_(self): | |
| self.onnx_trace = True | |
@staticmethod
def get_embedding(
| num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None | |
| ): | |
| """Build sinusoidal embeddings. | |
| This matches the implementation in tensor2tensor, but differs slightly | |
| from the description in Section 3.5 of "Attention Is All You Need". | |
| """ | |
| half_dim = embedding_dim // 2 | |
| emb = math.log(10000) / (half_dim - 1) | |
| emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) | |
| emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( | |
| 1 | |
| ) * emb.unsqueeze(0) | |
| emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( | |
| num_embeddings, -1 | |
| ) | |
| if embedding_dim % 2 == 1: | |
| # zero pad | |
| emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) | |
| if padding_idx is not None: | |
| emb[padding_idx, :] = 0 | |
| return emb | |
| def forward( | |
| self, | |
| input, | |
| incremental_state: Optional[Any] = None, | |
| timestep: Optional[Tensor] = None, | |
| positions: Optional[Any] = None, | |
| ): | |
| """Input is expected to be of size [bsz x seqlen].""" | |
| bspair = torch.onnx.operators.shape_as_tensor(input) | |
| bsz, seq_len = bspair[0], bspair[1] | |
| max_pos = self.padding_idx + 1 + seq_len | |
| if self.weights is None or max_pos > self.weights.size(0): | |
| # recompute/expand embeddings if needed | |
| self.weights = SinusoidalPositionalEmbedding.get_embedding( | |
| max_pos, self.embedding_dim, self.padding_idx | |
| ) | |
| self.weights = self.weights.to(self._float_tensor) | |
| if incremental_state is not None: | |
| # positions is the same for every token when decoding a single step | |
| pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len | |
| if self.onnx_trace: | |
| return ( | |
| self.weights.index_select(index=self.padding_idx + pos, dim=0) | |
| .unsqueeze(1) | |
| .repeat(bsz, 1, 1) | |
| ) | |
| return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) | |
| positions = utils_make_positions( | |
| input, self.padding_idx, onnx_trace=self.onnx_trace | |
| ) | |
| if self.onnx_trace: | |
| flat_embeddings = self.weights.detach().index_select(0, positions.view(-1)) | |
| embedding_shape = torch.cat( | |
| (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long)) | |
| ) | |
| embeddings = torch.onnx.operators.reshape_from_tensor_shape( | |
| flat_embeddings, embedding_shape | |
| ) | |
| return embeddings | |
| return ( | |
| self.weights.index_select(0, positions.view(-1)) | |
| .view(bsz, seq_len, -1) | |
| .detach() | |
| ) | |
| try: | |
| from apex.normalization import FusedLayerNorm as _FusedLayerNorm | |
| has_fused_layernorm = True | |
| class FusedLayerNorm(_FusedLayerNorm): | |
| def forward(self, x): | |
| if not x.is_cuda: | |
| return super().forward(x) | |
| else: | |
| with torch.cuda.device(x.device): | |
| return super().forward(x) | |
| except ImportError: | |
| has_fused_layernorm = False | |
| class Fp32LayerNorm(nn.LayerNorm): | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| def forward(self, input): | |
| output = F.layer_norm( | |
| input.float(), | |
| self.normalized_shape, | |
| self.weight.float() if self.weight is not None else None, | |
| self.bias.float() if self.bias is not None else None, | |
| self.eps, | |
| ) | |
| return output.type_as(input) | |
| class LayerDropModuleList(nn.ModuleList): | |
| """ | |
| A LayerDrop implementation based on :class:`torch.nn.ModuleList`. | |
| We refresh the choice of which layers to drop every time we iterate | |
| over the LayerDropModuleList instance. During evaluation we always | |
| iterate over all layers. | |
| Usage:: | |
| layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) | |
| for layer in layers: # this might iterate over layers 1 and 3 | |
| x = layer(x) | |
| for layer in layers: # this might iterate over all layers | |
| x = layer(x) | |
| for layer in layers: # this might not iterate over any layers | |
| x = layer(x) | |
| Args: | |
| p (float): probability of dropping out each layer | |
| modules (iterable, optional): an iterable of modules to add | |
| """ | |
| def __init__(self, p, modules=None): | |
| super().__init__(modules) | |
| self.p = p | |
| def __iter__(self): | |
| dropout_probs = torch.empty(len(self)).uniform_() | |
| for i, m in enumerate(super().__iter__()): | |
| if not self.training or (dropout_probs[i] > self.p): | |
| yield m | |
| class RelativePositionalEncoding(torch.nn.Module): | |
| def __init__(self, d_model, maxlen=1000, embed_v=False): | |
| super(RelativePositionalEncoding, self).__init__() | |
| self.d_model = d_model | |
| self.maxlen = maxlen | |
| self.pe_k = torch.nn.Embedding(2*maxlen, d_model) | |
| if embed_v: | |
| self.pe_v = torch.nn.Embedding(2*maxlen, d_model) | |
| self.embed_v = embed_v | |
| def forward(self, pos_seq, incremental_state=None): | |
| pos_seq[pos_seq < -self.maxlen] = -self.maxlen | |
| pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1 | |
| pos_seq = pos_seq + self.maxlen | |
| if incremental_state is not None: | |
| pos_seq = pos_seq[-1:] | |
| if self.embed_v: | |
| return self.pe_k(pos_seq), self.pe_v(pos_seq) | |
| else: | |
| return self.pe_k(pos_seq), None | |
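# Example (illustrative): build the relative-position key embeddings that the attention
# layers above consume as `pos_bias`, for a sequence of length 6. d_model would normally
# be embed_dim // num_heads.
def _relative_positional_encoding_example():
    rel_pos = RelativePositionalEncoding(d_model=64, maxlen=160)
    seq = torch.arange(6)
    pos_seq = seq[:, None] - seq[None, :]   # (6, 6) matrix of relative offsets
    pos_k, pos_v = rel_pos(pos_seq)
    return pos_k.shape, pos_v               # torch.Size([6, 6, 64]), None (embed_v=False)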
| class MultiheadAttention(nn.Module): | |
| """Multi-headed attention. | |
| See "Attention Is All You Need" for more details. | |
| """ | |
| def __init__( | |
| self, | |
| embed_dim, | |
| num_heads, | |
| kdim=None, | |
| vdim=None, | |
| dropout=0.0, | |
| bias=True, | |
| add_bias_kv=False, | |
| add_zero_attn=False, | |
| self_attention=False, | |
| encoder_decoder_attention=False, | |
| q_noise=0.0, | |
| qn_block_size=8, | |
| has_relative_attention_bias=False, | |
| scaling_for_att=1.0 | |
| ): | |
| super().__init__() | |
| self.embed_dim = embed_dim | |
| self.kdim = kdim if kdim is not None else embed_dim | |
| self.vdim = vdim if vdim is not None else embed_dim | |
| self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim | |
| self.num_heads = num_heads | |
| self.dropout_module = FairseqDropout( | |
| dropout, module_name=self.__class__.__name__ | |
| ) | |
| self.has_relative_attention_bias = has_relative_attention_bias | |
| self.head_dim = embed_dim // num_heads | |
| assert ( | |
| self.head_dim * num_heads == self.embed_dim | |
| ), "embed_dim must be divisible by num_heads" | |
| self.scaling = self.head_dim ** -0.5 | |
| self.scaling_for_att = scaling_for_att | |
| self.self_attention = self_attention | |
| self.encoder_decoder_attention = encoder_decoder_attention | |
| assert not self.self_attention or self.qkv_same_dim, ( | |
| "Self-attention requires query, key and " "value to be of the same size" | |
| ) | |
| self.k_proj = quant_noise( | |
| nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size | |
| ) | |
| self.v_proj = quant_noise( | |
| nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size | |
| ) | |
| self.q_proj = quant_noise( | |
| nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size | |
| ) | |
| self.out_proj = quant_noise( | |
| nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size | |
| ) | |
| if add_bias_kv: | |
| self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) | |
| self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) | |
| else: | |
| self.bias_k = self.bias_v = None | |
| self.add_zero_attn = add_zero_attn | |
| self.reset_parameters() | |
| self.onnx_trace = False | |
| def prepare_for_onnx_export_(self): | |
| self.onnx_trace = True | |
| def reset_parameters(self): | |
| if self.qkv_same_dim: | |
| # Empirically observed the convergence to be much better with | |
| # the scaled initialization | |
| nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) | |
| nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) | |
| nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) | |
| else: | |
| nn.init.xavier_uniform_(self.k_proj.weight) | |
| nn.init.xavier_uniform_(self.v_proj.weight) | |
| nn.init.xavier_uniform_(self.q_proj.weight) | |
| nn.init.xavier_uniform_(self.out_proj.weight) | |
| if self.out_proj.bias is not None: | |
| nn.init.constant_(self.out_proj.bias, 0.0) | |
| if self.bias_k is not None: | |
| nn.init.xavier_normal_(self.bias_k) | |
| if self.bias_v is not None: | |
| nn.init.xavier_normal_(self.bias_v) | |
| def forward( | |
| self, | |
| query, | |
| key: Optional[Tensor], | |
| value: Optional[Tensor], | |
| key_padding_mask: Optional[Tensor] = None, | |
| incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, | |
| need_weights: bool = True, | |
| static_kv: bool = False, | |
| attn_mask: Optional[Tensor] = None, | |
| before_softmax: bool = False, | |
| need_head_weights: bool = False, | |
| position_bias: Optional[Tensor] = None | |
| ) -> Tuple[Tensor, Optional[Tensor]]: | |
| """Input shape: Time x Batch x Channel | |
| Args: | |
| key_padding_mask (ByteTensor, optional): mask to exclude | |
| keys that are pads, of shape `(batch, src_len)`, where | |
| padding elements are indicated by 1s. | |
| need_weights (bool, optional): return the attention weights, | |
averaged over heads (default: True).
| attn_mask (ByteTensor, optional): typically used to | |
| implement causal attention, where the mask prevents the | |
| attention from looking forward in time (default: None). | |
| before_softmax (bool, optional): return the raw attention | |
| weights and values before the attention softmax. | |
| need_head_weights (bool, optional): return the attention | |
| weights for each head. Implies *need_weights*. Default: | |
| return the average attention weights over all heads. | |
| """ | |
| if need_head_weights: | |
| need_weights = True | |
| is_tpu = query.device.type == "xla" | |
| tgt_len, bsz, embed_dim = query.size() | |
| src_len = tgt_len | |
| assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}" | |
| assert list(query.size()) == [tgt_len, bsz, embed_dim] | |
| if key is not None: | |
| src_len, key_bsz, _ = key.size() | |
| if not torch.jit.is_scripting(): | |
| assert key_bsz == bsz | |
| assert value is not None | |
| assert (src_len, bsz) == value.shape[:2] | |
| if ( | |
| not self.onnx_trace | |
| and not is_tpu # don't use PyTorch version on TPUs | |
| and incremental_state is None | |
| and not static_kv | |
| # A workaround for quantization to work. Otherwise JIT compilation | |
| # treats bias in linear module as method. | |
| and not torch.jit.is_scripting() | |
| and not self.has_relative_attention_bias | |
| ): | |
| assert key is not None and value is not None | |
| return F.multi_head_attention_forward( | |
| query, | |
| key, | |
| value, | |
| self.embed_dim, | |
| self.num_heads, | |
| torch.empty([0]), | |
| torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), | |
| self.bias_k, | |
| self.bias_v, | |
| self.add_zero_attn, | |
| self.dropout_module.p, | |
| self.out_proj.weight, | |
| self.out_proj.bias, | |
| self.training or self.dropout_module.apply_during_inference, | |
| key_padding_mask, | |
| need_weights, | |
| attn_mask, | |
| use_separate_proj_weight=True, | |
| q_proj_weight=self.q_proj.weight, | |
| k_proj_weight=self.k_proj.weight, | |
| v_proj_weight=self.v_proj.weight, | |
| ) | |
| if incremental_state is not None: | |
| saved_state = self._get_input_buffer(incremental_state) | |
| if saved_state is not None and "prev_key" in saved_state: | |
| # previous time steps are cached - no need to recompute | |
| # key and value if they are static | |
| if static_kv: | |
| assert self.encoder_decoder_attention and not self.self_attention | |
| key = value = None | |
| else: | |
| saved_state = None | |
| if self.self_attention: | |
| q = self.q_proj(query) | |
| k = self.k_proj(query) | |
| v = self.v_proj(query) | |
| elif self.encoder_decoder_attention: | |
| # encoder-decoder attention | |
| q = self.q_proj(query) | |
| if key is None: | |
| assert value is None | |
| k = v = None | |
| else: | |
| k = self.k_proj(key) | |
| v = self.v_proj(key) | |
| else: | |
| assert key is not None and value is not None | |
| q = self.q_proj(query) | |
| k = self.k_proj(key) | |
| v = self.v_proj(value) | |
| q *= self.scaling | |
| q *= (1 / self.scaling_for_att) | |
| if self.bias_k is not None: | |
| assert self.bias_v is not None | |
| k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) | |
| v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) | |
| if attn_mask is not None: | |
| attn_mask = torch.cat( | |
| [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 | |
| ) | |
| if key_padding_mask is not None: | |
| key_padding_mask = torch.cat( | |
| [ | |
| key_padding_mask, | |
| key_padding_mask.new_zeros(key_padding_mask.size(0), 1), | |
| ], | |
| dim=1, | |
| ) | |
| q = ( | |
| q.contiguous() | |
| .view(tgt_len, bsz * self.num_heads, self.head_dim) | |
| .transpose(0, 1) | |
| ) | |
| if k is not None: | |
| k = ( | |
| k.contiguous() | |
| .view(-1, bsz * self.num_heads, self.head_dim) | |
| .transpose(0, 1) | |
| ) | |
| if v is not None: | |
| v = ( | |
| v.contiguous() | |
| .view(-1, bsz * self.num_heads, self.head_dim) | |
| .transpose(0, 1) | |
| ) | |
| if saved_state is not None: | |
| # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) | |
| if "prev_key" in saved_state: | |
| _prev_key = saved_state["prev_key"] | |
| assert _prev_key is not None | |
| prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) | |
| if static_kv: | |
| k = prev_key | |
| else: | |
| assert k is not None | |
| k = torch.cat([prev_key, k], dim=1) | |
| src_len = k.size(1) | |
| if "prev_value" in saved_state: | |
| _prev_value = saved_state["prev_value"] | |
| assert _prev_value is not None | |
| prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) | |
| if static_kv: | |
| v = prev_value | |
| else: | |
| assert v is not None | |
| v = torch.cat([prev_value, v], dim=1) | |
| prev_key_padding_mask: Optional[Tensor] = None | |
| if "prev_key_padding_mask" in saved_state: | |
| prev_key_padding_mask = saved_state["prev_key_padding_mask"] | |
| assert k is not None and v is not None | |
| key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( | |
| key_padding_mask=key_padding_mask, | |
| prev_key_padding_mask=prev_key_padding_mask, | |
| batch_size=bsz, | |
| src_len=k.size(1), | |
| static_kv=static_kv, | |
| ) | |
| saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) | |
| saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) | |
| saved_state["prev_key_padding_mask"] = key_padding_mask | |
| # In this branch incremental_state is never None | |
| assert incremental_state is not None | |
| incremental_state = self._set_input_buffer(incremental_state, saved_state) | |
| assert k is not None | |
| assert k.size(1) == src_len | |
| # This is part of a workaround to get around fork/join parallelism | |
| # not supporting Optional types. | |
| if key_padding_mask is not None and key_padding_mask.dim() == 0: | |
| key_padding_mask = None | |
| if key_padding_mask is not None: | |
| assert key_padding_mask.size(0) == bsz | |
| assert key_padding_mask.size(1) == src_len | |
| if self.add_zero_attn: | |
| assert v is not None | |
| src_len += 1 | |
| k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) | |
| v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) | |
| if attn_mask is not None: | |
| attn_mask = torch.cat( | |
| [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 | |
| ) | |
| if key_padding_mask is not None: | |
| key_padding_mask = torch.cat( | |
| [ | |
| key_padding_mask, | |
| torch.zeros(key_padding_mask.size(0), 1).type_as( | |
| key_padding_mask | |
| ), | |
| ], | |
| dim=1, | |
| ) | |
| attn_weights = torch.bmm(q, k.transpose(1, 2)) | |
| attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) | |
| if position_bias is not None: | |
| # add the first-order relative-position term: compare each query against the | |
| # relative position embeddings (position_bias has shape (seq_len, seq_len, head_dim)) | |
| reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)  # (tgt_len, bsz*heads, head_dim) | |
| B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))  # (tgt_len, bsz*heads, src_len) | |
| B = B.transpose(0, 1).view(bsz * self.num_heads, position_bias.size(0), position_bias.size(1)) | |
| attn_weights += B | |
| attn_weights *= self.scaling_for_att | |
| assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] | |
| if attn_mask is not None: | |
| attn_mask = attn_mask.unsqueeze(0) | |
| if self.onnx_trace: | |
| attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) | |
| attn_weights += attn_mask | |
| if key_padding_mask is not None: | |
| # don't attend to padding symbols | |
| attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) | |
| if not is_tpu: | |
| attn_weights = attn_weights.masked_fill( | |
| key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), | |
| float("-inf"), | |
| ) | |
| else: | |
| attn_weights = attn_weights.transpose(0, 2) | |
| attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) | |
| attn_weights = attn_weights.transpose(0, 2) | |
| attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) | |
| if self.scaling_for_att > 1.0: | |
| attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0] | |
| if before_softmax: | |
| return attn_weights, v | |
| attn_weights_float = softmax( | |
| attn_weights, dim=-1, onnx_trace=self.onnx_trace | |
| ) | |
| attn_weights = attn_weights_float.type_as(attn_weights) | |
| attn_probs = self.dropout_module(attn_weights) | |
| assert v is not None | |
| attn = torch.bmm(attn_probs, v) | |
| assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] | |
| if self.onnx_trace and attn.size(1) == 1: | |
| # when ONNX tracing a single decoder step (sequence length == 1) | |
| # the transpose is a no-op copy before view, thus unnecessary | |
| attn = attn.contiguous().view(tgt_len, bsz, embed_dim) | |
| else: | |
| attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) | |
| attn = self.out_proj(attn) | |
| attn_weights: Optional[Tensor] = None | |
| if need_weights: | |
| attn_weights = attn_weights_float.view( | |
| bsz, self.num_heads, tgt_len, src_len | |
| ).transpose(1, 0) | |
| if not need_head_weights: | |
| # average attention weights over heads | |
| attn_weights = attn_weights.mean(dim=0) | |
| return attn, attn_weights | |
| @staticmethod | |
| def _append_prev_key_padding_mask( | |
| key_padding_mask: Optional[Tensor], | |
| prev_key_padding_mask: Optional[Tensor], | |
| batch_size: int, | |
| src_len: int, | |
| static_kv: bool, | |
| ) -> Optional[Tensor]: | |
| # saved key padding masks have shape (bsz, seq_len) | |
| if prev_key_padding_mask is not None and static_kv: | |
| new_key_padding_mask = prev_key_padding_mask | |
| elif prev_key_padding_mask is not None and key_padding_mask is not None: | |
| new_key_padding_mask = torch.cat( | |
| [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 | |
| ) | |
| # During incremental decoding, as the padding token enters and | |
| # leaves the frame, there will be a time when prev or current | |
| # is None | |
| elif prev_key_padding_mask is not None: | |
| if src_len > prev_key_padding_mask.size(1): | |
| filler = torch.zeros( | |
| (batch_size, src_len - prev_key_padding_mask.size(1)), | |
| device=prev_key_padding_mask.device, | |
| ) | |
| new_key_padding_mask = torch.cat( | |
| [prev_key_padding_mask.float(), filler.float()], dim=1 | |
| ) | |
| else: | |
| new_key_padding_mask = prev_key_padding_mask.float() | |
| elif key_padding_mask is not None: | |
| if src_len > key_padding_mask.size(1): | |
| filler = torch.zeros( | |
| (batch_size, src_len - key_padding_mask.size(1)), | |
| device=key_padding_mask.device, | |
| ) | |
| new_key_padding_mask = torch.cat( | |
| [filler.float(), key_padding_mask.float()], dim=1 | |
| ) | |
| else: | |
| new_key_padding_mask = key_padding_mask.float() | |
| else: | |
| new_key_padding_mask = prev_key_padding_mask | |
| return new_key_padding_mask | |
| def reorder_incremental_state( | |
| self, | |
| incremental_state: Dict[str, Dict[str, Optional[Tensor]]], | |
| new_order: Tensor, | |
| ): | |
| """Reorder buffered internal state (for incremental generation).""" | |
| input_buffer = self._get_input_buffer(incremental_state) | |
| if input_buffer is not None: | |
| for k in input_buffer.keys(): | |
| input_buffer_k = input_buffer[k] | |
| if input_buffer_k is not None: | |
| if self.encoder_decoder_attention and input_buffer_k.size( | |
| 0 | |
| ) == new_order.size(0): | |
| break | |
| input_buffer[k] = input_buffer_k.index_select(0, new_order) | |
| incremental_state = self._set_input_buffer(incremental_state, input_buffer) | |
| return incremental_state | |
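| # Illustrative sketch, comments only: during beam search the generator reorders the | |
| # cached keys/values after every step, e.g. | |
| #   new_order = torch.tensor([0, 0, 2])  # hypothetical surviving-beam indices | |
| #   attn.reorder_incremental_state(incremental_state, new_order) | |
| # so the (batch * beam) cache entries track the surviving hypotheses. | |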
| def _get_input_buffer( | |
| self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] | |
| ) -> Dict[str, Optional[Tensor]]: | |
| result = self.get_incremental_state(incremental_state, "attn_state") | |
| if result is not None: | |
| return result | |
| else: | |
| empty_result: Dict[str, Optional[Tensor]] = {} | |
| return empty_result | |
| def _set_input_buffer( | |
| self, | |
| incremental_state: Dict[str, Dict[str, Optional[Tensor]]], | |
| buffer: Dict[str, Optional[Tensor]], | |
| ): | |
| return self.set_incremental_state(incremental_state, "attn_state", buffer) | |
| def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): | |
| return attn_weights | |
| def upgrade_state_dict_named(self, state_dict, name): | |
| prefix = name + "." if name != "" else "" | |
| items_to_add = {} | |
| keys_to_remove = [] | |
| for k in state_dict.keys(): | |
| if k.endswith(prefix + "in_proj_weight"): | |
| # in_proj_weight used to be q + k + v with same dimensions | |
| dim = int(state_dict[k].shape[0] / 3) | |
| items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] | |
| items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] | |
| items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] | |
| keys_to_remove.append(k) | |
| k_bias = prefix + "in_proj_bias" | |
| if k_bias in state_dict.keys(): | |
| dim = int(state_dict[k].shape[0] / 3) | |
| items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] | |
| items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ | |
| dim : 2 * dim | |
| ] | |
| items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] | |
| keys_to_remove.append(prefix + "in_proj_bias") | |
| for k in keys_to_remove: | |
| del state_dict[k] | |
| for key, value in items_to_add.items(): | |
| state_dict[key] = value | |
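| # Illustrative usage sketch: not part of the original model code and never called. | |
| # It assumes the standard fairseq-style MultiheadAttention constructor arguments | |
| # (embed_dim, num_heads, self_attention=...); the sizes below are hypothetical. | |
| def _example_multihead_attention_usage(): | |
| embed_dim, num_heads, tgt_len, bsz = 16, 4, 5, 2 | |
| attn = MultiheadAttention(embed_dim, num_heads, self_attention=True) | |
| x = torch.randn(tgt_len, bsz, embed_dim)  # Time x Batch x Channel, as documented in forward() | |
| pad_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)  # no positions are padding | |
| out, weights = attn(x, x, x, key_padding_mask=pad_mask, need_weights=True) | |
| return out.shape, weights.shape  # (5, 2, 16) and (2, 5, 5): weights averaged over heads | |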
| class ConvFeatureExtractionModel(nn.Module): | |
| """Stack of 1-D convolution blocks that maps a raw waveform of shape (B, T) to frame-level features of shape (B, C, T').""" | |
| def __init__( | |
| self, | |
| conv_layers: List[Tuple[int, int, int]], | |
| dropout: float = 0.0, | |
| mode: str = "default", | |
| conv_bias: bool = False, | |
| ): | |
| super().__init__() | |
| assert mode in {"default", "layer_norm"} | |
| def block( | |
| n_in, | |
| n_out, | |
| k, | |
| stride, | |
| is_layer_norm=False, | |
| is_group_norm=False, | |
| conv_bias=False, | |
| ): | |
| def make_conv(): | |
| conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) | |
| nn.init.kaiming_normal_(conv.weight) | |
| return conv | |
| assert not ( | |
| is_layer_norm and is_group_norm | |
| ), "layer norm and group norm are exclusive" | |
| if is_layer_norm: | |
| return nn.Sequential( | |
| make_conv(), | |
| nn.Dropout(p=dropout), | |
| nn.Sequential( | |
| TransposeLast(), | |
| Fp32LayerNorm(n_out, elementwise_affine=True), | |
| TransposeLast(), | |
| ), | |
| nn.GELU(), | |
| ) | |
| elif is_group_norm: | |
| return nn.Sequential( | |
| make_conv(), | |
| nn.Dropout(p=dropout), | |
| Fp32GroupNorm(n_out, n_out, affine=True), | |
| nn.GELU(), | |
| ) | |
| else: | |
| return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) | |
| in_d = 1 | |
| self.conv_layers = nn.ModuleList() | |
| for i, cl in enumerate(conv_layers): | |
| assert len(cl) == 3, "invalid conv definition: " + str(cl) | |
| (dim, k, stride) = cl | |
| self.conv_layers.append( | |
| block( | |
| in_d, | |
| dim, | |
| k, | |
| stride, | |
| is_layer_norm=mode == "layer_norm", | |
| is_group_norm=mode == "default" and i == 0, | |
| conv_bias=conv_bias, | |
| ) | |
| ) | |
| in_d = dim | |
| def forward(self, x): | |
| # BxT -> BxCxT | |
| x = x.unsqueeze(1) | |
| for conv in self.conv_layers: | |
| x = conv(x) | |
| return x | |
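| # Illustrative usage sketch: not part of the original model code and never called. | |
| # The conv_layers spec mirrors the common wav2vec 2.0-style feature extractor; a | |
| # given SpeechLM checkpoint may use a different configuration. | |
| def _example_conv_feature_extractor(): | |
| conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2  # (dim, kernel, stride) per block | |
| extractor = ConvFeatureExtractionModel(conv_layers, dropout=0.0, mode="default") | |
| wav = torch.randn(2, 16000)  # (batch, samples): one second of 16 kHz audio per example | |
| feats = extractor(wav)  # -> (2, 512, 49): roughly one frame every 20 ms | |
| return feats.shape | |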
| class TransposeLast(nn.Module): | |
| """Swap the last two dimensions of the input; if deconstruct_idx is set, index into the input first.""" | |
| def __init__(self, deconstruct_idx=None): | |
| super().__init__() | |
| self.deconstruct_idx = deconstruct_idx | |
| def forward(self, x): | |
| if self.deconstruct_idx is not None: | |
| x = x[self.deconstruct_idx] | |
| return x.transpose(-2, -1) | |
| class Fp32GroupNorm(nn.GroupNorm): | |
| """GroupNorm computed in float32 regardless of the input dtype, for numerical stability with fp16 inputs.""" | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| def forward(self, input): | |
| output = F.group_norm( | |
| input.float(), | |
| self.num_groups, | |
| self.weight.float() if self.weight is not None else None, | |
| self.bias.float() if self.bias is not None else None, | |
| self.eps, | |
| ) | |
| return output.type_as(input) | |
| class GradMultiply(torch.autograd.Function): | |
| """Identity in the forward pass; scales gradients by `scale` in the backward pass.""" | |
| @staticmethod | |
| def forward(ctx, x, scale): | |
| ctx.scale = scale | |
| res = x.new(x) | |
| return res | |
| @staticmethod | |
| def backward(ctx, grad): | |
| return grad * ctx.scale, None | |
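| # Illustrative usage, comments only: speech encoders commonly use GradMultiply to | |
| # down-weight gradients flowing into the convolutional feature extractor, e.g. | |
| #   features = GradMultiply.apply(features, 0.1)  # identity forward, 0.1x gradient | |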
| class Rotate3D(nn.Module): | |
| """ | |
| (T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D) | |
| """ | |
| def __init__(self): | |
| super().__init__() | |
| def forward(self, x): | |
| return x.permute(1, 2, 0) | |
| class SamePad(nn.Module): | |
| """Trim the extra right-side frames left by a padded convolution so the output length matches the input (removes 1 frame for even kernels, kernel_size - 1 for causal convs).""" | |
| def __init__(self, kernel_size, causal=False): | |
| super().__init__() | |
| if causal: | |
| self.remove = kernel_size - 1 | |
| else: | |
| self.remove = 1 if kernel_size % 2 == 0 else 0 | |
| def forward(self, x): | |
| if self.remove > 0: | |
| x = x[:, :, : -self.remove] | |
| return x | |
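| # Illustrative usage sketch: not part of the original model code and never called. | |
| # Shows the typical pairing of SamePad with a padded Conv1d, as in convolutional | |
| # relative positional encoders; the layer sizes below are hypothetical. | |
| def _example_same_pad(): | |
| kernel_size = 128 | |
| pos_conv = nn.Sequential( | |
| nn.Conv1d(768, 768, kernel_size=kernel_size, padding=kernel_size // 2, groups=16), | |
| SamePad(kernel_size), | |
| nn.GELU(), | |
| ) | |
| x = torch.randn(2, 768, 100)  # (batch, channels, time) | |
| return pos_conv(x).shape  # (2, 768, 100): SamePad trims the single extra frame | |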