in fairseq/tasks/semisupervised_translation.py [0:0]
def load_dataset(self, split, epoch=1, **kwargs):
    """Load a dataset split.

    Up to three families of datasets are built for ``split`` and stored
    together in ``self.datasets[split]`` as one ``RoundRobinZipDatasets``:

    * parallel data, one entry per pair in ``self.lang_pairs`` for which
      files exist (loaded when ``lambda_parallel`` is positive, when
      ``lambda_parallel_steps`` is set, or when ``split`` does not start
      with ``"train"``);
    * back-translation data (only for splits starting with ``"train"`` and
      when ``lambda_otf_bt`` / ``lambda_otf_bt_steps`` enable it); these are
      also stored in ``self.backtranslate_datasets``;
    * denoising data (only for splits starting with ``"train"`` and when
      ``lambda_denoising`` / ``lambda_denoising_steps`` enable it).

    Args:
        split: name of the split; it is the first component of every data
            file name looked up under the selected data path.
        epoch: selects the data path as ``paths[(epoch - 1) % len(paths)]``.
        **kwargs: accepted but not used by this method.

    Raises:
        FileNotFoundError: if parallel data is requested and no language
            pair has files for ``split``, or if back-translation is enabled
            and the monolingual file of a target language is missing.
    """
    paths = utils.split_paths(self.args.data)
    assert len(paths) > 0
    # Rotate through the configured data paths as the epoch advances.
    data_path = paths[(epoch - 1) % len(paths)]
    is_train = split.startswith("train")

    def split_path(src, tgt, lang):
        # Single source of truth for data file names. Monolingual files are
        # requested with tgt=None, which "{}" renders as the literal "None"
        # (e.g. "train.en-None.en").
        return os.path.join(
            data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
        )

    def split_exists(src, tgt, lang):
        return indexed_dataset.dataset_exists(
            split_path(src, tgt, lang), impl=self.args.dataset_impl
        )

    def load_indexed_dataset(path, dictionary):
        return data_utils.load_indexed_dataset(
            path, dictionary, self.args.dataset_impl
        )

    # load parallel datasets
    src_datasets, tgt_datasets = {}, {}
    if (
        self.lambda_parallel > 0.0
        or self.lambda_parallel_steps is not None
        or not is_train
    ):
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split("-")
            # The files may be named after either direction of the pair.
            if split_exists(src, tgt, src):
                prefix = os.path.join(
                    data_path, "{}.{}-{}.".format(split, src, tgt)
                )
            elif split_exists(tgt, src, src):
                prefix = os.path.join(
                    data_path, "{}.{}-{}.".format(split, tgt, src)
                )
            else:
                # No files for this pair; other pairs may still have data.
                continue
            src_datasets[lang_pair] = load_indexed_dataset(
                prefix + src, self.dicts[src]
            )
            tgt_datasets[lang_pair] = load_indexed_dataset(
                prefix + tgt, self.dicts[tgt]
            )
            # Same "<kind>-<id>: <path> <split> <n> examples" layout as the
            # back-translation and denoising messages below.
            logger.info(
                "parallel-{}: {} {} {} examples".format(
                    lang_pair, data_path, split, len(src_datasets[lang_pair])
                )
            )
        if not src_datasets:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, data_path)
            )

    # back translation datasets
    backtranslate_datasets = {}
    if (
        self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None
    ) and is_train:
        for lang_pair in self.lang_pairs:
            src, tgt = lang_pair.split("-")
            if not split_exists(tgt, None, tgt):
                raise FileNotFoundError(
                    "Dataset not found: backtranslation {} ({})".format(
                        split, data_path
                    )
                )
            dataset = load_indexed_dataset(
                split_path(tgt, None, tgt), self.dicts[tgt]
            )
            # Monolingual target-language text wrapped as a source-only
            # pair dataset; this is what is handed to BacktranslationDataset
            # as ``tgt_dataset``.
            lang_pair_dataset_tgt = LanguagePairDataset(
                dataset,
                dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            # Only the ``collater`` of this dataset is used (after language
            # token alteration) as the ``output_collater`` below.
            lang_pair_dataset = LanguagePairDataset(
                dataset,
                dataset.sizes,
                src_dict=self.dicts[src],
                tgt=dataset,
                tgt_sizes=dataset.sizes,
                tgt_dict=self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )
            bt_dataset = BacktranslationDataset(
                tgt_dataset=self.alter_dataset_langtok(
                    lang_pair_dataset_tgt,
                    src_eos=self.dicts[tgt].eos(),
                    src_lang=tgt,
                    tgt_lang=src,
                ),
                backtranslation_fn=self.backtranslators[lang_pair],
                src_dict=self.dicts[src],
                tgt_dict=self.dicts[tgt],
                output_collater=self.alter_dataset_langtok(
                    lang_pair_dataset=lang_pair_dataset,
                    src_eos=self.dicts[src].eos(),
                    src_lang=src,
                    tgt_eos=self.dicts[tgt].eos(),
                    tgt_lang=tgt,
                ).collater,
            )
            backtranslate_datasets[lang_pair] = bt_dataset
            logger.info(
                "backtranslate-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(bt_dataset),
                )
            )
            # Also keep a reference on the task itself, keyed by pair.
            self.backtranslate_datasets[lang_pair] = bt_dataset

    # denoising autoencoder
    noising_datasets = {}
    if (
        self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None
    ) and is_train:
        for lang_pair in self.lang_pairs:
            _, tgt = lang_pair.split("-")
            # Unlike back-translation, a missing monolingual file is not an
            # error here; the pair is simply skipped.
            if not split_exists(tgt, None, tgt):
                continue
            filename = split_path(tgt, None, tgt)
            # NOTE(review): the same file is loaded twice, presumably so the
            # noised source and the clean target do not share one dataset
            # object -- confirm before deduplicating.
            tgt_dataset1 = load_indexed_dataset(filename, self.dicts[tgt])
            tgt_dataset2 = load_indexed_dataset(filename, self.dicts[tgt])
            noising_dataset = NoisingDataset(
                tgt_dataset1,
                self.dicts[tgt],
                seed=1,
                max_word_shuffle_distance=self.args.max_word_shuffle_distance,
                word_dropout_prob=self.args.word_dropout_prob,
                word_blanking_prob=self.args.word_blanking_prob,
            )
            # Source side is the noised text, target side the original text,
            # both in the target language of the pair.
            noising_datasets[lang_pair] = self.alter_dataset_langtok(
                LanguagePairDataset(
                    noising_dataset,
                    tgt_dataset1.sizes,
                    self.dicts[tgt],
                    tgt_dataset2,
                    tgt_dataset2.sizes,
                    self.dicts[tgt],
                    left_pad_source=self.args.left_pad_source,
                    left_pad_target=self.args.left_pad_target,
                ),
                src_eos=self.dicts[tgt].eos(),
                src_lang=tgt,
                tgt_eos=self.dicts[tgt].eos(),
                tgt_lang=tgt,
            )
            logger.info(
                "denoising-{}: {} {} {} examples".format(
                    tgt,
                    data_path,
                    split,
                    len(noising_datasets[lang_pair]),
                )
            )

    def language_pair_dataset(lang_pair):
        # Wrap the loaded parallel source/target datasets of one pair.
        src, tgt = lang_pair.split("-")
        src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
        return self.alter_dataset_langtok(
            LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.dicts[src],
                tgt_dataset,
                tgt_dataset.sizes,
                self.dicts[tgt],
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            ),
            self.dicts[src].eos(),
            src,
            self.dicts[tgt].eos(),
            tgt,
        )

    # Entry order: parallel pairs first, then back-translation entries, then
    # denoising entries, each under its own key.
    self.datasets[split] = RoundRobinZipDatasets(
        OrderedDict(
            [
                (lang_pair, language_pair_dataset(lang_pair))
                for lang_pair in src_datasets
            ]
            + [
                (_get_bt_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in backtranslate_datasets.items()
            ]
            + [
                (_get_denoising_dataset_key(lang_pair), dataset)
                for lang_pair, dataset in noising_datasets.items()
            ]
        ),
        eval_key=None
        if self.training
        else "%s-%s" % (self.args.source_lang, self.args.target_lang),
    )