in pytorch_translate/dual_learning/dual_learning_task.py [0:0]
def load_dataset(self, split, seed=None):
    """Load a data split and store it in ``self.datasets[split]``.

    Two splits are supported:

    * ``self.args.valid_subset`` -- always parallel data, loaded for both
      directions: primal (``self.source_lang`` -> ``self.target_lang``) and
      dual (``self.target_lang`` -> ``self.source_lang``).
    * ``self.args.train_subset`` -- the same primal/dual parallel data plus
      source-only monolingual data for each direction.

    All datasets of a split are combined into a single
    ``RoundRobinZipDatasets`` keyed by ``"primal_parallel"``,
    ``"dual_parallel"`` and, for training only, ``"primal_source"`` and
    ``"dual_source"``.

    Args:
        split: name of the split to load; must equal either
            ``self.args.valid_subset`` or ``self.args.train_subset``.
        seed: not used by this method; accepted for interface compatibility.

    Raises:
        ValueError: if ``split`` is neither the validation nor the training
            subset.
    """

    def load_parallel(
        source_lang,
        target_lang,
        src_bin_path,
        tgt_bin_path,
        source_dictionary,
        target_dictionary,
    ):
        # Settings shared by every parallel load in this task; only the
        # direction, binary paths and dictionaries differ between calls.
        dataset, _, _ = data_utils.load_parallel_dataset(
            source_lang=source_lang,
            target_lang=target_lang,
            src_bin_path=src_bin_path,
            tgt_bin_path=tgt_bin_path,
            source_dictionary=source_dictionary,
            target_dictionary=target_dictionary,
            split=split,
            remove_eos_from_source=not self.args.append_eos_to_source,
            append_eos_to_target=True,
            char_source_dict=None,
            log_verbose=self.args.log_verbose,
        )
        return dataset

    def source_only(dataset, dictionary):
        # Wrap monolingual data as a pair dataset that has no target side.
        return LanguagePairDataset(
            src=dataset,
            src_sizes=dataset.sizes,
            src_dict=dictionary,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )

    if split == self.args.valid_subset:
        # The tune set is always parallel; the dual direction is the same
        # task with source and target flipped.
        primal_parallel = load_parallel(
            self.source_lang,
            self.target_lang,
            self.args.forward_eval_source_binary_path,
            self.args.forward_eval_target_binary_path,
            self.primal_src_dict,
            self.primal_tgt_dict,
        )
        dual_parallel = load_parallel(
            self.target_lang,
            self.source_lang,
            self.args.backward_eval_source_binary_path,
            self.args.backward_eval_target_binary_path,
            self.dual_src_dict,
            self.dual_tgt_dict,
        )
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    ("primal_parallel", primal_parallel),
                    ("dual_parallel", dual_parallel),
                ]
            )
        )
    elif split == self.args.train_subset:
        src_dataset = data_utils.load_monolingual_dataset(
            self.args.train_mono_source_binary_path, is_source=True
        )
        # is_source=True is deliberate for the target-language data too: it
        # is fed as the *source* side of the dual model (dual_source_mono).
        tgt_dataset = data_utils.load_monolingual_dataset(
            self.args.train_mono_target_binary_path, is_source=True
        )
        # NOTE(review): unlike the parallel data, EOS handling of these
        # monolingual sources is not controlled by append_eos_to_source
        # here -- confirm this is intended.
        primal_source_mono = source_only(src_dataset, self.primal_src_dict)
        dual_source_mono = source_only(tgt_dataset, self.dual_src_dict)
        primal_parallel = load_parallel(
            self.source_lang,
            self.target_lang,
            self.args.forward_train_source_binary_path,
            self.args.forward_train_target_binary_path,
            self.primal_src_dict,
            self.primal_tgt_dict,
        )
        dual_parallel = load_parallel(
            self.target_lang,
            self.source_lang,
            self.args.backward_train_source_binary_path,
            self.args.backward_train_target_binary_path,
            self.dual_src_dict,
            self.dual_tgt_dict,
        )
        self.datasets[split] = RoundRobinZipDatasets(
            OrderedDict(
                [
                    ("primal_parallel", primal_parallel),
                    ("dual_parallel", dual_parallel),
                    ("primal_source", primal_source_mono),
                    ("dual_source", dual_source_mono),
                ]
            )
        )
    else:
        raise ValueError(
            "Invalid data split: {!r} (expected {!r} or {!r}).".format(
                split, self.args.valid_subset, self.args.train_subset
            )
        )