def load_dataset()

in fairseq/tasks/multilingual_language_modeling.py [0:0]


    def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        One ``MonolingualDataset`` is built per language, the per-language
        datasets are concatenated, and the result is stored (wrapped in a
        ``SortDataset``) in ``self.datasets[split]``.

        For the training subset every language is wrapped in a
        ``ResamplingDataset`` whose size ratio is derived from
        ``self._get_sample_prob``. For any other split the per-language
        datasets are additionally registered as ``"<split>_<language>"`` and,
        if ``split`` is an entry of ``self.args.valid_subset``, that entry is
        expanded to also list the per-language splits.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            epoch (int): forwarded to ``_get_langs`` and ``ResamplingDataset``
                and added to ``self.args.seed`` to seed the final shuffle
            combine (bool): forwarded to ``data_utils.load_indexed_dataset``
            **kwargs: accepted but not used by this method

        Raises:
            FileNotFoundError: if the offline shard ratio file is configured
                but does not exist, or if the data for a language is missing.
            ValueError: if a language has no entry in the offline shard ratio
                file (training subset only).
        """
        languages, data_path = MultilingualLanguageModelingTask._get_langs(
            self.args, epoch
        )

        # Optional per-language multipliers, read from a TSV file with one
        # "<lang>\t<ratio>" pair per line.
        lang_to_offline_shard_ratio = None
        if self.args.lang_to_offline_shard_ratio != "":
            ratio_path = self.args.lang_to_offline_shard_ratio
            # Explicit raise instead of ``assert``: asserts vanish under -O.
            if not os.path.exists(ratio_path):
                raise FileNotFoundError(
                    "provided offline shard ratio file doesn't exist: {0}".format(
                        ratio_path
                    )
                )
            lang_to_offline_shard_ratio = {}
            with open(ratio_path) as fin:
                for line in fin:
                    line = line.strip()
                    if not line:
                        # Tolerate blank / trailing empty lines.
                        continue
                    lang, ratio = line.split("\t")
                    lang_to_offline_shard_ratio[lang] = float(ratio)

            logger.info(
                "Found offline sharded ratio: %s",
                lang_to_offline_shard_ratio,
            )

        is_train_split = split == self.args.train_subset
        if is_train_split:
            logger.info(
                "Training on {0} languages: {1}".format(len(languages), languages)
            )
        else:
            logger.info(
                "Evaluating on {0} languages: {1}".format(len(languages), languages)
            )

        # One token fewer per block when a BOS token is added (presumably to
        # leave room for the language token - confirm against the dataset).
        tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token)

        fixed_pad_length = (
            self.args.tokens_per_sample if self.args.pad_to_fixed_length else None
        )

        pad_to_bsz = None
        if self.args.pad_to_fixed_bsz:
            pad_to_bsz = (
                self.args.batch_size_valid if "valid" in split else self.args.batch_size
            )

        # Loop-invariant: depends only on the configured break mode.
        add_eos_for_other_targets = (
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != "none"
        )

        lang_datasets = []
        for language in languages:
            split_path = os.path.join(data_path, language, split)
            dataset = data_utils.load_indexed_dataset(
                split_path, self.dictionary, self.args.dataset_impl, combine=combine
            )
            if dataset is None:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, split_path)
                )

            dataset = maybe_shorten_dataset(
                dataset,
                split,
                self.args.shorten_data_split_list,
                self.args.shorten_method,
                tokens_per_sample,
                self.args.seed,
            )

            dataset = TokenBlockDataset(
                dataset,
                dataset.sizes,
                tokens_per_sample,
                pad=self.dictionary.pad(),
                eos=self.dictionary.eos(),
                break_mode=self.args.sample_break_mode,
                include_targets=True,
            )

            # Language token indices are only looked up when a BOS token is used.
            src_lang_idx, tgt_lang_idx = None, None
            if self.args.add_bos_token:
                src_lang_idx = self.dictionary.index(lang_token(language))
                tgt_lang_idx = self.output_dictionary.index(lang_token(language))

            lang_datasets.append(
                MonolingualDataset(
                    dataset=dataset,
                    sizes=dataset.sizes,
                    src_vocab=self.dictionary,
                    tgt_vocab=self.output_dictionary,
                    add_eos_for_other_targets=add_eos_for_other_targets,
                    shuffle=True,
                    targets=self.targets,
                    fixed_pad_length=fixed_pad_length,
                    pad_to_bsz=pad_to_bsz,
                    add_bos_token=self.args.add_bos_token,
                    src_lang_idx=src_lang_idx,
                    tgt_lang_idx=tgt_lang_idx,
                )
            )

        dataset_lengths = np.array(
            [len(d) for d in lang_datasets],
            dtype=float,
        )
        logger.info(
            "loaded total {} blocks for all languages".format(
                dataset_lengths.sum(),
            )
        )

        if is_train_split:
            if lang_to_offline_shard_ratio is not None:
                missing = [
                    lang for lang in languages if lang not in lang_to_offline_shard_ratio
                ]
                if missing:
                    raise ValueError(
                        "Lang: {0} missing in offline shard ratio file: {1}".format(
                            missing[0],
                            self.args.lang_to_offline_shard_ratio,
                        )
                    )
                dataset_lengths_ratio_multiplier = np.array(
                    [lang_to_offline_shard_ratio[lang] for lang in languages]
                )
                # Scale the loaded lengths by the configured per-language ratio.
                true_dataset_lengths = (
                    dataset_lengths * dataset_lengths_ratio_multiplier
                )
            else:
                true_dataset_lengths = dataset_lengths

            # For train subset, additionally up or down sample languages.
            sample_probs = self._get_sample_prob(true_dataset_lengths)

            logger.info(
                "Sample probability by language: %s",
                {
                    lang: "{0:.4f}".format(prob)
                    for lang, prob in zip(languages, sample_probs)
                },
            )
            size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
            # TODO: add an option for shrinking all size ratios to below 1
            # if self.args.multilang_sampling_alpha != 1:
            #   size_ratio /= size_ratio.max()

            # Fix numeric errors in size ratio computation
            #   0.999999999999999999 -> 1
            #   1.000000000000000002 -> 1
            size_ratio = np.round(size_ratio, 8)

            logger.info(
                "Up/Down Sampling ratio by language: %s",
                {
                    lang: "{0:.2f}".format(ratio)
                    for lang, ratio in zip(languages, size_ratio)
                },
            )
            logger.info(
                "Actual dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(lang_dataset))
                    for lang, lang_dataset in zip(languages, lang_datasets)
                },
            )
            resampled_lang_datasets = [
                ResamplingDataset(
                    lang_dataset,
                    size_ratio=ratio,
                    seed=self.args.seed,
                    epoch=epoch,
                    # Sample with replacement only when up-sampling.
                    replace=ratio > 1.0,
                )
                for lang_dataset, ratio in zip(lang_datasets, size_ratio)
            ]
            logger.info(
                "Resampled dataset size by language: %s",
                {
                    lang: "{0:.2f}".format(len(resampled))
                    for lang, resampled in zip(languages, resampled_lang_datasets)
                },
            )
            dataset = ConcatDataset(resampled_lang_datasets)
        else:
            dataset = ConcatDataset(lang_datasets)
            lang_splits = [split]
            for language, lang_dataset in zip(languages, lang_datasets):
                split_name = split + "_" + language
                lang_splits.append(split_name)
                self.datasets[split_name] = lang_dataset

            # [TODO]: This is hacky for now to print validation ppl for each
            # language individually. Maybe need task API changes to allow it
            # in more generic ways.
            #
            # Match whole comma-separated entries rather than substrings so
            # that e.g. "valid1" is not rewritten when loading "valid", and
            # skip names already present so repeated calls are idempotent.
            subsets = self.args.valid_subset.split(",")
            if split in subsets:
                expanded = []
                for name in subsets:
                    replacements = lang_splits if name == split else [name]
                    for candidate in replacements:
                        if candidate not in expanded:
                            expanded.append(candidate)
                self.args.valid_subset = ",".join(expanded)

        # Deterministic permutation, seeded per epoch.
        with data_utils.numpy_seed(self.args.seed + epoch):
            shuffle = np.random.permutation(len(dataset))

        self.datasets[split] = SortDataset(
            dataset,
            sort_order=[
                shuffle,
                dataset.sizes,
            ],
        )