in src/fairseq/fairseq/tasks/sentence_ranking.py
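# Imports used by this excerpt (not shown in the snippet). A best-guess
# reconstruction based on the fairseq.data module layout; exact paths may
# differ across fairseq versions:
import logging
import os

import numpy as np

from fairseq.data import (
    ConcatSentencesDataset,
    IdDataset,
    NestedDictionaryDataset,
    NumelDataset,
    NumSamplesDataset,
    PrependTokenDataset,
    RawLabelDataset,
    RightPadDataset,
    SortDataset,
    TruncateDataset,
    data_utils,
)
from fairseq.data.shorten_dataset import maybe_shorten_dataset

logger = logging.getLogger(__name__)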
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test)."""
    def get_path(type, split):
        return os.path.join(self.args.data, type, split)

    def make_dataset(type, dictionary):
        split_path = get_path(type, split)
        dataset = data_utils.load_indexed_dataset(
            split_path,
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        return dataset
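
    # input0 holds the shared context; input1..input{num_classes} hold the
    # candidate options that will each be ranked against it.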
    input0 = make_dataset('input0', self.source_dictionary)
    input_options = [
        make_dataset(
            'input{idx}'.format(idx=idx + 1),
            self.source_dictionary,
        )
        for idx in range(self.args.num_classes)
    ]
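
    # Optionally prepend a separator token (e.g. a sentence separator such
    # as EOS) to the shared context, so option and context are delimited
    # after concatenation.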
    if self.args.separator_token is not None:
        input0 = PrependTokenDataset(input0, self.args.separator_token)
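
    # Build one token stream per option: prepend the init token (e.g. BOS),
    # truncate the option, concatenate option + context, then shorten the
    # result if this split is whitelisted for shortening.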
    src_tokens = []
    for input_option in input_options:
        if self.args.init_token is not None:
            input_option = PrependTokenDataset(input_option, self.args.init_token)
        if self.args.max_option_length is not None:
            input_option = TruncateDataset(input_option, self.args.max_option_length)
        src_token = ConcatSentencesDataset(input_option, input0)
        src_token = maybe_shorten_dataset(
            src_token,
            split,
            self.args.shorten_data_split_whitelist,
            self.args.shorten_method,
            self.args.max_positions,
            self.args.seed,
        )
        src_tokens.append(src_token)
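
    # Draw the permutation under a fixed seed so the shuffle order is
    # reproducible across runs.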
    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens[0]))
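
    # Assemble the batch dictionary: shared bookkeeping fields first, then
    # one net_input{i} entry per option (right-padded tokens plus lengths).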
    dataset = {
        'id': IdDataset(),
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens[0], reduce=True),
    }

    for src_token_idx in range(len(src_tokens)):
        dataset.update(
            {
                'net_input{idx}'.format(idx=src_token_idx + 1): {
                    'src_tokens': RightPadDataset(
                        src_tokens[src_token_idx],
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'src_lengths': NumelDataset(src_tokens[src_token_idx], reduce=False),
                }
            }
        )
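
    # Targets are optional: a split may ship without a label file (e.g. a
    # blind test set), in which case no 'target' entry is added.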
    label_path = '{}.label'.format(get_path('label', split))
    if os.path.exists(label_path):
        with open(label_path) as h:
            dataset.update(
                target=RawLabelDataset([int(x.strip()) for x in h])
            )
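
    # The per-example size reported to the batching code is the elementwise
    # maximum over the per-option lengths.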
    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
    )
    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # sort by the seeded random permutation drawn above,
            # i.e. shuffle the examples deterministically
            sort_order=[shuffle],
        )
    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]
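
# A minimal sketch of how a consumer might score this layout at train time.
# score_options is a hypothetical helper, not fairseq's actual
# sentence_ranking criterion; it assumes the model returns (logits, extra)
# when called with the src_tokens/src_lengths produced above.
import torch

def score_options(model, sample, num_classes):
    """Run one forward pass per option; return scores of shape (batch, num_classes)."""
    scores = []
    for idx in range(num_classes):
        net_input = sample['net_input{idx}'.format(idx=idx + 1)]
        logits, _ = model(**net_input)  # assumes a single-logit ranking head
        scores.append(logits.squeeze(-1))
    return torch.stack(scores, dim=1)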