in fairseq/tasks/multilingual_language_modeling.py
def load_dataset(self, split: str, epoch=1, combine=False, **kwargs):
    """Load a given dataset split.

    Args:
        split (str): name of the split (e.g., train, valid, test)
    """
    languages, data_path = MultilingualLanguageModelingTask._get_langs(
        self.args, epoch
    )
    lang_to_offline_shard_ratio = None
    if self.args.lang_to_offline_shard_ratio != "":
        lang_to_offline_shard_ratio = {}
        assert os.path.exists(
            self.args.lang_to_offline_shard_ratio
        ), "provided offline shard ratio file doesn't exist: {0}".format(
            self.args.lang_to_offline_shard_ratio
        )
        with open(self.args.lang_to_offline_shard_ratio) as fin:
            for line in fin:
                lang, ratio = line.strip().split("\t")
                ratio = float(ratio)
                lang_to_offline_shard_ratio[lang] = ratio
        logger.info(
            "Found offline shard ratios: %s",
            lang_to_offline_shard_ratio,
        )
    if split == self.args.train_subset:
        logger.info(
            "Training on {0} languages: {1}".format(len(languages), languages)
        )
    else:
        logger.info(
            "Evaluating on {0} languages: {1}".format(len(languages), languages)
        )

    tokens_per_sample = self.args.tokens_per_sample - int(self.args.add_bos_token)

    fixed_pad_length = None
    if self.args.pad_to_fixed_length:
        fixed_pad_length = self.args.tokens_per_sample

    pad_to_bsz = None
    if self.args.pad_to_fixed_bsz:
        pad_to_bsz = (
            self.args.batch_size_valid if "valid" in split else self.args.batch_size
        )
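    # Build a MonolingualDataset per language; each language's data for this
    # split is expected under <data_path>/<language>/<split>.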
    lang_datasets = []
    for lang_id, language in enumerate(languages):
        split_path = os.path.join(data_path, language, split)
        dataset = data_utils.load_indexed_dataset(
            split_path, self.dictionary, self.args.dataset_impl, combine=combine
        )
        # print('len(dataset) =', len(dataset))
        if dataset is None:
            raise FileNotFoundError(
                "Dataset not found: {} ({})".format(split, split_path)
            )
        dataset = maybe_shorten_dataset(
            dataset,
            split,
            self.args.shorten_data_split_list,
            self.args.shorten_method,
            tokens_per_sample,
            self.args.seed,
        )
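        # Chunk the (possibly shortened) token stream into blocks of up to
        # `tokens_per_sample` tokens, split according to --sample-break-mode.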
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            tokens_per_sample,
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            include_targets=True,
        )

        add_eos_for_other_targets = (
            self.args.sample_break_mode is not None
            and self.args.sample_break_mode != "none"
        )
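        # When --add-bos-token is set, look up the index of the
        # language-specific token in the source and target vocabularies so
        # MonolingualDataset can prepend it to each sample.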
        src_lang_idx, tgt_lang_idx = None, None
        if self.args.add_bos_token:
            src_lang_idx = self.dictionary.index(lang_token(language))
            tgt_lang_idx = self.output_dictionary.index(lang_token(language))

        lang_datasets.append(
            MonolingualDataset(
                dataset=dataset,
                sizes=dataset.sizes,
                src_vocab=self.dictionary,
                tgt_vocab=self.output_dictionary,
                add_eos_for_other_targets=add_eos_for_other_targets,
                shuffle=True,
                targets=self.targets,
                fixed_pad_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
                add_bos_token=self.args.add_bos_token,
                src_lang_idx=src_lang_idx,
                tgt_lang_idx=tgt_lang_idx,
            )
        )
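    # Number of blocks per language; for the training split these counts
    # drive the up/down-sampling below.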
    dataset_lengths = np.array(
        [len(d) for d in lang_datasets],
        dtype=float,
    )
    logger.info(
        "loaded {} total blocks for all languages".format(
            dataset_lengths.sum(),
        )
    )
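    # Training split: resample each language so the mixture follows the
    # smoothed sampling distribution before concatenation. Other splits:
    # concatenate as-is and also register one dataset per language
    # (e.g. "valid_<lang>") so metrics can be reported per language.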
    if split == self.args.train_subset:
        dataset_lengths_ratio_multiplier = np.ones(len(dataset_lengths))
        if lang_to_offline_shard_ratio is not None:
            dataset_lengths_ratio_multiplier = []
            for lang in languages:
                assert (
                    lang in lang_to_offline_shard_ratio
                ), "Lang: {0} missing in offline shard ratio file: {1}".format(
                    lang,
                    self.args.lang_to_offline_shard_ratio,
                )
                dataset_lengths_ratio_multiplier.append(
                    lang_to_offline_shard_ratio[lang]
                )
            dataset_lengths_ratio_multiplier = np.array(
                dataset_lengths_ratio_multiplier
            )
            true_dataset_lengths = (
                dataset_lengths * dataset_lengths_ratio_multiplier
            )
        else:
            true_dataset_lengths = dataset_lengths

        # For train subset, additionally up or down sample languages.
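        # _get_sample_prob smooths the empirical distribution of language
        # sizes with --multilang-sampling-alpha (temperature-based sampling),
        # up-sampling low-resource languages relative to their raw share.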
        sample_probs = self._get_sample_prob(true_dataset_lengths)
        logger.info(
            "Sample probability by language: %s",
            {
                lang: "{0:.4f}".format(sample_probs[id])
                for id, lang in enumerate(languages)
            },
        )
        size_ratio = (sample_probs * true_dataset_lengths.sum()) / dataset_lengths
        # TODO: add an option for shrinking all size ratios to below 1
        # if self.args.multilang_sampling_alpha != 1:
        #   size_ratio /= size_ratio.max()

        # Fix numeric errors in size ratio computation
        #   0.999999999999999999 -> 1
        #   1.000000000000000002 -> 1
        for i in range(len(size_ratio)):
            size_ratio[i] = round(size_ratio[i], 8)
        logger.info(
            "Up/Down Sampling ratio by language: %s",
            {
                lang: "{0:.2f}".format(size_ratio[id])
                for id, lang in enumerate(languages)
            },
        )
        logger.info(
            "Actual dataset size by language: %s",
            {
                lang: "{0:.2f}".format(len(lang_datasets[id]))
                for id, lang in enumerate(languages)
            },
        )
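        # A ResamplingDataset exposes roughly len(dataset) * size_ratio
        # samples per epoch; sampling with replacement is only needed when a
        # language is up-sampled (size_ratio > 1).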
        resampled_lang_datasets = [
            ResamplingDataset(
                lang_datasets[i],
                size_ratio=size_ratio[i],
                seed=self.args.seed,
                epoch=epoch,
                replace=size_ratio[i] > 1.0,
            )
            for i, d in enumerate(lang_datasets)
        ]
        logger.info(
            "Resampled dataset size by language: %s",
            {
                lang: "{0:.2f}".format(len(resampled_lang_datasets[id]))
                for id, lang in enumerate(languages)
            },
        )
        dataset = ConcatDataset(resampled_lang_datasets)
    else:
        dataset = ConcatDataset(lang_datasets)
        lang_splits = [split]
        for lang_id, lang_dataset in enumerate(lang_datasets):
            split_name = split + "_" + languages[lang_id]
            lang_splits.append(split_name)
            self.datasets[split_name] = lang_dataset

        # [TODO]: This is hacky for now to print validation ppl for each
        # language individually. Maybe need task API changes to allow it
        # in more generic ways.
        if split in self.args.valid_subset:
            self.args.valid_subset = self.args.valid_subset.replace(
                split, ",".join(lang_splits)
            )
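    # Final ordering: an epoch-seeded shuffle plus a size-based sort key;
    # SortDataset lexsorts over these so similarly sized blocks end up
    # adjacent for batching while the order still changes across epochs.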
    with data_utils.numpy_seed(self.args.seed + epoch):
        shuffle = np.random.permutation(len(dataset))

    self.datasets[split] = SortDataset(
        dataset,
        sort_order=[
            shuffle,
            dataset.sizes,
        ],
    )