def load_dataset()

in pytorch_translate/dual_learning/dual_learning_task.py [0:0]


    def load_dataset(self, split, seed=None):
        """Load ``split`` into ``self.datasets[split]``.

        The validation split (``self.args.valid_subset``) is always parallel
        data. It is stored as a ``RoundRobinZipDatasets`` with the keys
        ``"primal_parallel"`` (source -> target) and ``"dual_parallel"``
        (target -> source).

        The training split (``self.args.train_subset``) holds the same two
        parallel entries plus two monolingual entries:

        * ``"primal_source"``: source-language monolingual data (primal dicts).
        * ``"dual_source"``: target-language monolingual data (dual dicts).

        Args:
            split: Name of the split to load. Must equal either
                ``self.args.valid_subset`` or ``self.args.train_subset``.
            seed: Not used by this implementation.

        Raises:
            ValueError: If ``split`` is neither the validation nor the
                training subset.
        """
        if split == self.args.valid_subset:
            # The tune set is always parallel: load it in both directions.
            self.datasets[split] = RoundRobinZipDatasets(
                self._load_primal_dual_parallel(
                    split,
                    forward_src_path=self.args.forward_eval_source_binary_path,
                    forward_tgt_path=self.args.forward_eval_target_binary_path,
                    backward_src_path=self.args.backward_eval_source_binary_path,
                    backward_tgt_path=self.args.backward_eval_target_binary_path,
                )
            )
        elif split == self.args.train_subset:
            primal_source_mono = self._load_mono_source(
                self.args.train_mono_source_binary_path, self.primal_src_dict
            )
            # Target-language monolingual text is the *source* side of the
            # dual model, so it is wrapped with the dual source dictionary.
            dual_source_mono = self._load_mono_source(
                self.args.train_mono_target_binary_path, self.dual_src_dict
            )
            datasets = self._load_primal_dual_parallel(
                split,
                forward_src_path=self.args.forward_train_source_binary_path,
                forward_tgt_path=self.args.forward_train_target_binary_path,
                backward_src_path=self.args.backward_train_source_binary_path,
                backward_tgt_path=self.args.backward_train_target_binary_path,
            )
            # Appended after the parallel entries to keep the key order
            # primal_parallel, dual_parallel, primal_source, dual_source.
            datasets["primal_source"] = primal_source_mono
            datasets["dual_source"] = dual_source_mono
            self.datasets[split] = RoundRobinZipDatasets(datasets)
        else:
            raise ValueError("Invalid data split.")

    def _load_parallel_direction(
        self,
        split,
        source_lang,
        target_lang,
        src_bin_path,
        tgt_bin_path,
        source_dictionary,
        target_dictionary,
    ):
        """Load one direction of parallel data using the task's shared settings.

        Returns only the combined parallel dataset; the separate source and
        target datasets returned by ``data_utils.load_parallel_dataset`` are
        discarded.
        """
        parallel_dataset, _, _ = data_utils.load_parallel_dataset(
            source_lang=source_lang,
            target_lang=target_lang,
            src_bin_path=src_bin_path,
            tgt_bin_path=tgt_bin_path,
            source_dictionary=source_dictionary,
            target_dictionary=target_dictionary,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        return parallel_dataset

    def _load_primal_dual_parallel(
        self,
        split,
        forward_src_path,
        forward_tgt_path,
        backward_src_path,
        backward_tgt_path,
    ):
        """Load primal (src->tgt) and dual (tgt->src) parallel data for ``split``.

        Returns an ``OrderedDict`` with keys ``"primal_parallel"`` then
        ``"dual_parallel"``.
        """
        primal_parallel = self._load_parallel_direction(
            split,
            source_lang=self.source_lang,
            target_lang=self.target_lang,
            src_bin_path=forward_src_path,
            tgt_bin_path=forward_tgt_path,
            source_dictionary=self.primal_src_dict,
            target_dictionary=self.primal_tgt_dict,
        )
        # The dual direction flips the languages and uses the dual dictionaries.
        dual_parallel = self._load_parallel_direction(
            split,
            source_lang=self.target_lang,
            target_lang=self.source_lang,
            src_bin_path=backward_src_path,
            tgt_bin_path=backward_tgt_path,
            source_dictionary=self.dual_src_dict,
            target_dictionary=self.dual_tgt_dict,
        )
        return OrderedDict(
            [
                ("primal_parallel", primal_parallel),
                ("dual_parallel", dual_parallel),
            ]
        )

    def _load_mono_source(self, bin_path, src_dict):
        """Wrap a binarized monolingual file as a source-only LanguagePairDataset."""
        # is_source=True because this data is always used on the source side
        # of whichever model consumes it.
        # NOTE(review): unlike the parallel loaders, log_verbose is not
        # forwarded here; confirm whether load_monolingual_dataset accepts it.
        mono_dataset = data_utils.load_monolingual_dataset(bin_path, is_source=True)
        return LanguagePairDataset(
            src=mono_dataset,
            src_sizes=mono_dataset.sizes,
            src_dict=src_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )