def load_dataset()

in fairseq/tasks/semisupervised_translation.py [0:0]


    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split and store it in ``self.datasets[split]``.

        Up to three families of sub-datasets are built and zipped together in
        a :class:`RoundRobinZipDatasets`:

        * parallel data for every ``src-tgt`` in ``self.lang_pairs``; loaded
          for any split whose name does not start with ``"train"``, and for
          train splits when ``lambda_parallel > 0`` or
          ``lambda_parallel_steps`` is set;
        * on-the-fly back-translation data built from the monolingual
          ``<split>.<tgt>-None.<tgt>`` files (train splits only, when
          ``lambda_otf_bt > 0`` or ``lambda_otf_bt_steps`` is set);
        * denoising-autoencoder data built from the same monolingual files
          (train splits only, when ``lambda_denoising > 0`` or
          ``lambda_denoising_steps`` is set).

        Args:
            split: split name, e.g. ``"train"`` or ``"valid"``. Names starting
                with ``"train"`` are treated as training splits.
            epoch: 1-based epoch number; selects which entry of
                ``utils.split_paths(self.args.data)`` to read from, cycling
                round-robin over the entries.
            **kwargs: accepted for interface compatibility; unused here.

        Raises:
            FileNotFoundError: if parallel data is required but found for no
                language pair, or if back-translation is enabled and a
                monolingual file for some target language is missing.
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        # Cycle through the configured data directories, one per epoch.
        data_path = paths[(epoch - 1) % len(paths)]
        is_train = split.startswith("train")

        def split_path(split, src, tgt, lang):
            """Return ``<data_path>/<split>.<src>-<tgt>.<lang>``.

            A single format string covers every case: ``str.format`` renders
            ``None`` as ``"None"``, so monolingual data is addressed as
            ``split_path(split, lang, None, lang)`` ->
            ``<split>.<lang>-None.<lang>``.
            """
            return os.path.join(
                data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
            )

        def split_exists(split, src, tgt, lang):
            """Whether the indexed dataset at ``split_path(...)`` exists."""
            return indexed_dataset.dataset_exists(
                split_path(split, src, tgt, lang), impl=self.args.dataset_impl
            )

        def load_indexed_dataset(path, dictionary):
            return data_utils.load_indexed_dataset(
                path, dictionary, self.args.dataset_impl
            )

        # load parallel datasets (always for splits not starting with "train")
        src_datasets, tgt_datasets = {}, {}
        if (
            self.lambda_parallel > 0.0
            or self.lambda_parallel_steps is not None
            or not is_train
        ):
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                # The bitext may be stored under either direction's name, e.g.
                # for "en-de": "train.en-de.en" or "train.de-en.en".
                if split_exists(split, src, tgt, src):
                    file_src, file_tgt = src, tgt
                elif split_exists(split, tgt, src, src):
                    file_src, file_tgt = tgt, src
                else:
                    continue
                src_datasets[lang_pair] = load_indexed_dataset(
                    split_path(split, file_src, file_tgt, src), self.dicts[src]
                )
                tgt_datasets[lang_pair] = load_indexed_dataset(
                    split_path(split, file_src, file_tgt, tgt), self.dicts[tgt]
                )
                logger.info(
                    "parallel-{} {} {} examples".format(
                        data_path, split, len(src_datasets[lang_pair])
                    )
                )
            if not src_datasets:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        # back translation datasets
        backtranslate_datasets = {}
        if (
            self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
        ) and is_train:
            for lang_pair in self.lang_pairs:
                src, tgt = lang_pair.split("-")
                # Monolingual target-side text is mandatory when OTF
                # back-translation is enabled.
                if not split_exists(split, tgt, None, tgt):
                    raise FileNotFoundError(
                        "Dataset not found: backtranslation {} ({})".format(
                            split, data_path
                        )
                    )
                dataset = load_indexed_dataset(
                    split_path(split, tgt, None, tgt), self.dicts[tgt]
                )
                # Source-only view of the monolingual tgt text; it is tagged
                # src_lang=tgt / tgt_lang=src below and handed to
                # BacktranslationDataset together with
                # self.backtranslators[lang_pair] (presumably translated
                # tgt -> src on the fly; see BacktranslationDataset).
                lang_pair_dataset_tgt = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                # Only the ``collater`` of this src->tgt pair dataset is used,
                # as BacktranslationDataset's ``output_collater``.
                lang_pair_dataset = LanguagePairDataset(
                    dataset,
                    dataset.sizes,
                    src_dict=self.dicts[src],
                    tgt=dataset,
                    tgt_sizes=dataset.sizes,
                    tgt_dict=self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                )
                bt_dataset = BacktranslationDataset(
                    tgt_dataset=self.alter_dataset_langtok(
                        lang_pair_dataset_tgt,
                        src_eos=self.dicts[tgt].eos(),
                        src_lang=tgt,
                        tgt_lang=src,
                    ),
                    backtranslation_fn=self.backtranslators[lang_pair],
                    src_dict=self.dicts[src],
                    tgt_dict=self.dicts[tgt],
                    output_collater=self.alter_dataset_langtok(
                        lang_pair_dataset=lang_pair_dataset,
                        src_eos=self.dicts[src].eos(),
                        src_lang=src,
                        tgt_eos=self.dicts[tgt].eos(),
                        tgt_lang=tgt,
                    ).collater,
                )
                backtranslate_datasets[lang_pair] = bt_dataset
                logger.info(
                    "backtranslate-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(bt_dataset),
                    )
                )
                # Also keep a reference on the task instance.
                self.backtranslate_datasets[lang_pair] = bt_dataset

        # denoising autoencoder
        noising_datasets = {}
        if (
            self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
        ) and is_train:
            for lang_pair in self.lang_pairs:
                _, tgt = lang_pair.split("-")
                # Unlike back-translation, missing monolingual data is skipped.
                if not split_exists(split, tgt, None, tgt):
                    continue
                filename = split_path(split, tgt, None, tgt)
                # Two separately loaded copies of the same file: the first is
                # wrapped in NoisingDataset and used as the source side, the
                # second is used unmodified as the target side.
                tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
                tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
                noising_dataset = NoisingDataset(
                    tgt_dataset1,
                    self.dicts[tgt],
                    seed=1,
                    max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                    word_dropout_prob=self.args.word_dropout_prob,
                    word_blanking_prob=self.args.word_blanking_prob,
                )
                noising_datasets[lang_pair] = self.alter_dataset_langtok(
                    LanguagePairDataset(
                        noising_dataset,
                        tgt_dataset1.sizes,
                        self.dicts[tgt],
                        tgt_dataset2,
                        tgt_dataset2.sizes,
                        self.dicts[tgt],
                        left_pad_source=self.args.left_pad_source,
                        left_pad_target=self.args.left_pad_target,
                    ),
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                )
                logger.info(
                    "denoising-{}: {} {} {} examples".format(
                        tgt,
                        data_path,
                        split,
                        len(noising_datasets[lang_pair]),
                    )
                )

        def language_pair_dataset(lang_pair):
            """Wrap the loaded parallel data for ``lang_pair`` with lang tokens."""
            src, tgt = lang_pair.split("-")
            src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
            return self.alter_dataset_langtok(
                LanguagePairDataset(
                    src_dataset,
                    src_dataset.sizes,
                    self.dicts[src],
                    tgt_dataset,
                    tgt_dataset.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                self.dicts[src].eos(),
                src,
                self.dicts[tgt].eos(),
                tgt,
            )

        # Parallel entries first, then back-translation, then denoising; the
        # eval key (when not training) selects the source_lang-target_lang
        # parallel entry.
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    (lang_pair, language_pair_dataset(lang_pair))
                    for lang_pair in src_datasets.keys()
                ]
                + [
                    (_get_bt_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in backtranslate_datasets.items()
                ]
                + [
                    (_get_denoising_dataset_key(lang_pair), dataset)
                    for lang_pair, dataset in noising_datasets.items()
                ]
            ),
            eval_key=None
            if self.training
            else "%s-%s" % (self.args.source_lang, self.args.target_lang),
        )