Spaces:
Build error
Build error
| # ---------------------------------------------------------------------------- | |
| # SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730) | |
| # Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT | |
| # Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4 | |
| # | |
| # Copyright (c) 2022 Microsoft | |
| # Licensed under The MIT License [see LICENSE for details] | |
| # ---------------------------------------------------------------------------- | |
| import contextlib | |
| import torch | |
| from dataclasses import dataclass, field | |
| from fairseq import utils | |
| from fairseq.models import BaseFairseqModel, register_model | |
| from fairseq.models.fairseq_encoder import FairseqEncoder | |
| from fairseq.models.hubert import HubertAsrConfig, HubertEncoder | |
| from fairseq.tasks import FairseqTask | |
@dataclass
class SpeechUTASRConfig(HubertAsrConfig):
    """Fine-tuning configuration for SpeechUT ASR.

    Extends ``HubertAsrConfig`` with a single switch that controls whether
    the model keeps a decoder next to the CTC branch.
    """

    # When False, SpeechUTASR sets the wrapped model's decoder to None and the
    # model behaves as an encoder-CTC model.
    add_decoder: bool = field(
        default=True,
        metadata={"help": "add decoder for fine-tune"},
    )
# NOTE(review): `register_model` is imported by this file but no
# `@register_model(...)` decorator is present on this class -- confirm whether
# the model still needs to be registered with fairseq.
class SpeechUTASR(BaseFairseqModel):
    """
    An encoder-ctc-decoder model if cfg.add_decoder is True, or an encoder-ctc model
    """

    def __init__(self, cfg: SpeechUTASRConfig, encoder: FairseqEncoder):
        """Wrap ``encoder``; drop its decoder when ``cfg.add_decoder`` is False."""
        super().__init__()
        self.cfg = cfg
        self.encoder = encoder
        if not cfg.add_decoder:
            self.encoder.w2v_model.decoder = None

    def upgrade_state_dict_named(self, state_dict, name):
        """Delegate to the base-class upgrade and return ``state_dict``."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: SpeechUTASRConfig, task: FairseqTask):
        """Build a new model instance."""
        encoder = SpeechUTEncoder(cfg, task)
        return cls(cfg, encoder)

    def forward(self, source, padding_mask, prev_output_tokens, **kwargs):
        """Run the encoder, the CTC projection and (optionally) the decoder.

        Returns a dict with:
            ``encoder_out_ctc``: projected encoder states for the CTC loss,
            ``padding_mask``: padding mask matching the CTC branch,
            ``decoder_out``: decoder output, or None when
                ``cfg.add_decoder`` is False.
        """
        encoder_out = self.encoder(source, padding_mask, **kwargs)

        x = self.encoder.final_dropout(encoder_out["encoder_out"][0])  # (T, B, C)
        # Explicit None check: module truthiness is not a reliable "is set" test.
        if self.encoder.proj is not None:
            x = self.encoder.proj(x)

        if self.encoder.conv_ctc_proj:
            padding_mask = self.encoder.w2v_model.downsample_ctc_padding_mask(
                encoder_out["encoder_padding_mask"][0]
            )
        else:
            # NOTE(review): this branch returns the list-wrapped mask produced
            # by the encoder, while the branch above passes the unwrapped
            # element to the downsampler -- confirm what callers expect.
            padding_mask = encoder_out["encoder_padding_mask"]

        if self.cfg.add_decoder:
            decoder_out = self.decoder(
                prev_output_tokens, encoder_out=encoder_out, **kwargs
            )
        else:
            decoder_out = None

        return {
            "encoder_out_ctc": x,  # (T, B, C), for CTC loss
            "padding_mask": padding_mask,  # (B, T), for CTC loss
            "decoder_out": decoder_out,  # for ED loss
        }

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Call the decoder directly with the given arguments."""
        return self.decoder(prev_output_tokens, **kwargs)

    def get_logits(self, net_output):
        """For CTC decoding.

        Reads ``net_output["encoder_out"]`` and
        ``net_output["encoder_padding_mask"]``; either may be a tensor or a
        list wrapping one (the layout the encoder in this file returns).
        At padded positions the logits are overwritten in place with 0 for
        class index 0 and ``-inf`` for every other class.
        """
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if isinstance(logits, list):
            logits = logits[0]
        if isinstance(padding, list):
            padding = padding[0] if padding else None

        if padding is not None and padding.any():
            # A single masked assignment is required here: chained indexing
            # such as ``logits[mask][..., 0] = 0`` writes into a temporary
            # copy and leaves ``logits`` untouched.
            fill = logits.new_full((logits.size(-1),), float("-inf"))
            fill[0] = 0
            logits[padding.T] = fill
        return logits

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """For 1) computing CTC loss, 2) decoder decoding."""
        if "encoder_out_ctc" not in net_output:
            # No CTC output present: defer to the decoder's normalization.
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)

        logits = net_output["encoder_out_ctc"]
        if isinstance(logits, list):
            logits = logits[0]

        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        return utils.softmax(logits.float(), dim=-1)

    @property
    def decoder(self):
        """The decoder owned by the wrapped model (None if it was dropped)."""
        return self.encoder.w2v_model.decoder
class SpeechUTEncoder(HubertEncoder):
    """
    Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
    1. make it compatible with encoder-decoder model
    """

    def __init__(self, cfg: HubertAsrConfig, task):
        """Build the parent encoder, then pick the CTC projection head.

        If the task has a target dictionary and the wrapped model exposes
        ``unit_encoder_ctc_head``, that head replaces ``self.proj`` and
        ``conv_ctc_proj`` is set to True; otherwise ``conv_ctc_proj`` is False.
        """
        super().__init__(cfg, task)

        use_conv_ctc_head = task.target_dictionary is not None and hasattr(
            self.w2v_model, "unit_encoder_ctc_head"
        )
        if use_conv_ctc_head:
            self.proj = self.w2v_model.unit_encoder_ctc_head
        self.conv_ctc_proj = use_conv_ctc_head

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        """Extract features with the wrapped model.

        Masking is requested only when ``self.apply_mask`` is set and the
        module is in training mode. Until ``num_updates`` reaches
        ``freeze_finetune_updates`` the wrapped model runs under
        ``torch.no_grad()``.

        Returns a dict with list-wrapped ``encoder_out`` and
        ``encoder_padding_mask``. ``encoder_out`` is transposed on its first
        two dimensions when ``tbc`` is True.
        """
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        finetune = self.freeze_finetune_updates <= self.num_updates
        grad_ctx = contextlib.ExitStack() if finetune else torch.no_grad()

        with grad_ctx:
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
        }

    def forward_torchscript(self, net_input):
        """A TorchScript-compatible version of forward.

        Forward the encoder out. Features are extracted with masking
        disabled; when a projection is set, its output is also returned
        under ``encoder_out_ctc``.
        """
        x, padding_mask = self.w2v_model.extract_features(**net_input, mask=False)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = {
            "encoder_out": [x],
            "encoder_padding_mask": [padding_mask],
        }
        # Explicit None check: module truthiness is not a reliable "is set" test.
        if self.proj is not None:
            x = self.proj(x)
            encoder_out["encoder_out_ctc"] = x

        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension of ``encoder_out`` by ``new_order``.

        ``encoder_out`` entries are indexed on dim 1 and
        ``encoder_padding_mask`` entries on dim 0. The dict is updated in
        place and returned.
        """
        # NOTE(review): an "encoder_out_ctc" entry (added by
        # forward_torchscript) is not reordered here -- confirm that is intended.
        states = encoder_out["encoder_out"]
        if states is not None:
            encoder_out["encoder_out"] = [
                s.index_select(1, new_order) for s in states
            ]

        masks = encoder_out["encoder_padding_mask"]
        if masks is not None:
            encoder_out["encoder_padding_mask"] = [
                m.index_select(0, new_order) for m in masks
            ]
        return encoder_out