# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
| """ | |
| Modified from https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/tasks/translation.py | |
| 1. Add custom lang_format in function load_langpair_dataset | |
| 2. If truncate_source (default no), use RandomCropDataset instead of TruncateDataset | |
| """ | |
import itertools
import logging
import os

from fairseq.data import (
    AppendTokenDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    RandomCropDataset,
    data_utils,
    indexed_dataset,
)
from speechut.data.concat_dataset import ConcatDataset

EVAL_BLEU_ORDER = 4

logger = logging.getLogger(__name__)

def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    lang_format="[{}]",
    input_feeding=True,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []
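
    # Accumulate data shards: the base split plus any numbered shards
    # (e.g. "train", "train1", "train2", ...) until one is missing.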
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
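        # One of the noted modifications: over-long sources are randomly
        # cropped (RandomCropDataset) rather than truncated from the end,
        # keeping one position free for the re-appended EOS token.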
        if truncate_source:
            src_dataset = AppendTokenDataset(
                RandomCropDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break
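
    # Either every shard has a target side, or the dataset is source-only.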
    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None
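
    # Optionally prepend BOS: either the dictionary BOS on both sides, or a
    # caller-provided token index on the source side only.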
    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
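    # The other noted modification: language-ID tokens are built with the
    # configurable lang_format (default "[{}]", e.g. "[en]") and appended to
    # each side; the target language token also becomes the EOS passed below.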
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index(lang_format.format(src))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index(lang_format.format(tgt))
            )
        eos = tgt_dict.index(lang_format.format(tgt))
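
    # Load word alignments for the pair if requested and present on disk.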
    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
        input_feeding=input_feeding,
    )
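

# Minimal usage sketch (illustrative only, not part of the original module).
# The data directory, dictionary files, and language codes below are
# placeholder assumptions; in SpeechUT this helper is normally called from a
# fairseq translation task's load_dataset() rather than invoked directly.
#
#   from fairseq.data import Dictionary
#
#   src_dict = Dictionary.load("data-bin/example/dict.en.txt")  # hypothetical paths
#   tgt_dict = Dictionary.load("data-bin/example/dict.de.txt")
#   dataset = load_langpair_dataset(
#       data_path="data-bin/example",
#       split="train",
#       src="en",
#       src_dict=src_dict,
#       tgt="de",
#       tgt_dict=tgt_dict,
#       combine=True,
#       dataset_impl="mmap",            # assuming mmap-binarized data
#       upsample_primary=1,
#       left_pad_source=True,
#       left_pad_target=False,
#       max_source_positions=1024,
#       max_target_positions=1024,
#       truncate_source=True,           # randomly crop over-long sources
#   )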