# NMT/src/data/loader.py
def load_para_data(params, data):
    """
    Load parallel data.
    """
    assert len(params.para_dataset) > 0

    for (lang1, lang2), paths in params.para_dataset.items():

        assert lang1 in params.langs and lang2 in params.langs
        logger.info('============ Parallel data (%s - %s)' % (lang1, lang2))

        datasets = []

        for name, path in zip(['train', 'valid', 'test'], paths):

            # an empty path means no parallel data for this split (train only)
            if path == '':
                assert name == 'train'
                datasets.append((name, None))
                continue
            assert name != 'train' or params.n_para != 0

            # load data
            data1 = load_binarized(path.replace('XX', lang1), params)
            data2 = load_binarized(path.replace('XX', lang2), params)
            set_parameters(params, data1['dico'])
            set_parameters(params, data2['dico'])

            # set / check dictionaries
            if lang1 not in data['dico']:
                data['dico'][lang1] = data1['dico']
            else:
                assert data['dico'][lang1] == data1['dico']
            if lang2 not in data['dico']:
                data['dico'][lang2] = data2['dico']
            else:
                assert data['dico'][lang2] == data2['dico']

            # parallel data
            para_data = ParallelDataset(
                data1['sentences'], data1['positions'], data['dico'][lang1], params.lang2id[lang1],
                data2['sentences'], data2['positions'], data['dico'][lang2], params.lang2id[lang2],
                params
            )

            # remove too long sentences (train / valid only, test must remain unchanged)
            if name != 'test':
                para_data.remove_long_sentences(params.max_len)

            # select a subset of training sentences (-1 keeps everything)
            if name == 'train' and params.n_para != -1:
                para_data.select_data(0, params.n_para)
            # if name == 'valid':
            #     para_data.select_data(0, 100)
            # if name == 'test':
            #     para_data.select_data(0, 167)

            datasets.append((name, para_data))

        assert (lang1, lang2) not in data['para']
        data['para'][(lang1, lang2)] = {k: v for k, v in datasets}

    logger.info('')
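
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how load_para_data might be driven, assuming a
# `params` object like the one built by the repo's argument parser and a
# `data` dict carrying the 'dico' and 'para' keys this function reads and
# writes. The language pair and file paths below are hypothetical; the only
# structural facts taken from the code are that each language pair maps to
# (train, valid, test) paths and that 'XX' is substituted with each language.
#
#   data = {'dico': {}, 'para': {}}
#   params.para_dataset = {
#       ('en', 'fr'): (
#           'data/para/train.XX.pth',   # hypothetical paths; 'XX' -> 'en' / 'fr'
#           'data/para/valid.XX.pth',
#           'data/para/test.XX.pth',
#       ),
#   }
#   load_para_data(params, data)
#   train_set = data['para'][('en', 'fr')]['train']   # ParallelDataset, or None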