in src/fairseq/fairseq/tasks/sentence_prediction.py [0:0]
def load_dataset(self, split, combine=False, **kwargs):
    """Load a given dataset split (e.g., train, valid, test).

    Reads the ``input0`` (required) and ``input1`` (optional) indexed
    datasets found under ``self.args.data``, joins them into a single
    token stream per sample, optionally attaches targets, and stores the
    result in ``self.datasets[split]``.

    Args:
        split: name of the split to load; used as the last path component
            under each of the ``input0`` / ``input1`` / ``label``
            sub-directories of ``self.args.data``.
        combine: forwarded unchanged to ``data_utils.load_indexed_dataset``.
        **kwargs: accepted for interface compatibility; not used here.

    Returns:
        The dataset stored in ``self.datasets[split]``: the nested
        dictionary dataset itself when ``self.args.no_shuffle`` is set,
        otherwise that dataset wrapped in a ``SortDataset`` ordered by a
        seeded random permutation.

    Raises:
        AssertionError: if the ``input0`` dataset cannot be found, or if a
            line of the regression label file does not contain exactly
            ``self.args.num_classes`` values.
    """

    def get_path(key, split):
        # <data>/<key>/<split>, e.g. <data>/input0/train
        return os.path.join(self.args.data, key, split)

    def make_dataset(key, dictionary):
        # The result is compared against None by the callers below, which
        # is how a missing optional dataset (input1, label) is detected.
        return data_utils.load_indexed_dataset(
            get_path(key, split),
            dictionary,
            self.args.dataset_impl,
            combine=combine,
        )

    input0 = make_dataset('input0', self.source_dictionary)
    # The message must name the 'input0' path explicitly: there is no local
    # ``type`` in this scope, so referring to it would hand the builtin
    # ``type`` to os.path.join and mask this error with a TypeError.
    assert input0 is not None, 'could not find dataset: {}'.format(
        get_path('input0', split)
    )
    input1 = make_dataset('input1', self.source_dictionary)

    if self.args.init_token is not None:
        input0 = PrependTokenDataset(input0, self.args.init_token)

    if input1 is None:
        # Single-sentence task: only the first input is present.
        src_tokens = input0
    else:
        if self.args.separator_token is not None:
            input1 = PrependTokenDataset(input1, self.args.separator_token)
        src_tokens = ConcatSentencesDataset(input0, input1)

    # Draw the sample permutation inside the seeded context
    # (presumably this makes the order reproducible for a given seed —
    # confirm against data_utils.numpy_seed).
    with data_utils.numpy_seed(self.args.seed):
        shuffle = np.random.permutation(len(src_tokens))

    src_tokens = maybe_shorten_dataset(
        src_tokens,
        split,
        self.args.shorten_data_split_whitelist,
        self.args.shorten_method,
        self.args.max_positions,
        self.args.seed,
    )

    dataset = {
        'id': IdDataset(),
        'net_input': {
            'src_tokens': RightPadDataset(
                src_tokens,
                pad_idx=self.source_dictionary.pad(),
            ),
            'src_lengths': NumelDataset(src_tokens, reduce=False),
        },
        'nsentences': NumSamplesDataset(),
        'ntokens': NumelDataset(src_tokens, reduce=True),
    }

    if self.args.add_prev_output_tokens:
        prev_tokens_dataset = RightPadDataset(
            RollDataset(src_tokens, 1),
            pad_idx=self.dictionary.pad(),
        )
        dataset['net_input'].update(
            prev_output_tokens=prev_tokens_dataset,
        )

    if not self.args.regression_target:
        # Classification: labels come from an indexed dataset. The target
        # is only attached when that dataset exists.
        label_dataset = make_dataset('label', self.label_dictionary)
        if label_dataset is not None:
            dataset.update(
                target=OffsetTokensDataset(
                    StripTokenDataset(
                        label_dataset,
                        id_to_strip=self.label_dictionary.eos(),
                    ),
                    # Shift label ids down by the dictionary's nspecial
                    # (presumably so class ids start at 0 — confirm).
                    offset=-self.label_dictionary.nspecial,
                )
            )
    else:
        # Regression: targets are whitespace-separated floats, one sample
        # per line, in a plain-text "<split>.label" file.
        label_path = "{0}.label".format(get_path('label', split))
        if os.path.exists(label_path):

            def parse_regression_target(i, line):
                values = line.split()
                assert len(values) == self.args.num_classes, \
                    f'expected num_classes={self.args.num_classes} regression target values on line {i}, found: "{line}"'
                return [float(x) for x in values]

            # Read inside a context manager so the file handle is closed
            # deterministically, including when a line fails to parse.
            with open(label_path) as label_file:
                regression_targets = [
                    parse_regression_target(i, line.strip())
                    for i, line in enumerate(label_file)
                ]
            dataset.update(target=RawLabelDataset(regression_targets))

    nested_dataset = NestedDictionaryDataset(
        dataset,
        sizes=[src_tokens.sizes],
    )

    if self.args.no_shuffle:
        dataset = nested_dataset
    else:
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

    logger.info("Loaded {0} with #samples: {1}".format(split, len(dataset)))

    self.datasets[split] = dataset
    return self.datasets[split]