in code/src/data/loader.py [0:0]
import os
from logging import getLogger

import numpy as np
import torch

# NOTE: module-level context assumed from the surrounding file: the UNK token
# constant, the shared logger, and the cross-call dataset cache
from .dictionary import UNK_WORD

logger = getLogger()

loaded_data = {}


def load_binarized(path, params):
"""
Load a binarized dataset and log main statistics.
"""
if path in loaded_data:
logger.info("Reloading data loaded from %s ..." % path)
return loaded_data[path]
assert os.path.isfile(path), path
logger.info("Loading data from %s ..." % path)
data = torch.load(path)
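    # 'positions' is an (n_sentences, 2) array of (start, end) indices into the
    # flat 'sentences' tensor; one extra index (the end-of-sentence marker) is
    # stored per sentence, hence the subtraction of len(positions) when
    # counting actual words below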
data['positions'] = data['positions'].numpy()
logger.info("%i words (%i unique) in %i sentences with %i attributes. %i unknown words (%i unique)." % (
len(data['sentences']) - len(data['positions']),
len(data['dico']), len(data['positions']),
len(data['attr_values']),
sum(data['unk_words'].values()), len(data['unk_words'])
))
# add length attribute if required
len_attrs = [attr for attr in params.attributes if attr.startswith('length_')]
assert len(len_attrs) <= 1
if len(len_attrs) == 1:
len_attr = len_attrs[0]
assert len_attr[len('length_'):].isdigit()
bs = int(len_attr[len('length_'):])
lm = params.max_len
assert bs >= 1 and lm >= 1 and len_attr not in data['attr_values']
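        # bucket boundaries: multiples of bs that cover max_len, e.g. with
        # bs=5 and lm=12 the line below gives sr = [0, 5, 10, 15]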
sr = np.arange(0, lm + 1 if lm % bs == 0 else lm + bs - lm % bs + 1, bs)
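        # map each sentence length to its bucket: ceil(length / bs) - 1, then
        # clip so sentences longer than sr[-1] fall into the last bucket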
len_labels = np.ceil((data['positions'][:, 1] - data['positions'][:, 0]).astype(np.float32) / bs) - 1
len_labels = np.minimum(len_labels.astype(np.int64), len(sr) - 2)
assert len_labels.min() >= 0
data['attr_values'][len_attr] = ['%s-%s' % (sr[i], sr[i + 1]) for i in range(len(sr) - 1)]
params.size_ranges = sr
params.bucket_size = bs
else:
len_attr = None
# maximum vocabulary size
if params.max_vocab != -1:
assert params.max_vocab > 0
logger.info("Selecting %i most frequent words ..." % params.max_vocab)
data['dico'].prune(params.max_vocab)
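        # word indices >= max_vocab no longer exist after pruning: remap them to UNK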
        data['sentences'].masked_fill_(data['sentences'] >= params.max_vocab, data['dico'].index(UNK_WORD))
        unk_count = (data['sentences'] == data['dico'].index(UNK_WORD)).sum().item()
logger.info("Now %i unknown words covering %.2f%% of the data." % (
unk_count, 100. * unk_count / (len(data['sentences']) - len(data['positions']))
))
# select relevant attributes
assert data['attributes'].size() == (len(data['positions']), len(data['attr_values']) - len(len_attrs))
assert all(x in data['attr_values'] for x in params.attributes)
    # the columns of data['attributes'] follow the sorted keys of the original
    # attr_values dict, so exclude the length attribute added above (it would
    # otherwise shift the indices of keys that sort after it)
    avail_attrs = sorted(k for k in data['attr_values'].keys() if k != len_attr)
    attr_idx = [avail_attrs.index(x) for x in params.attributes if x != len_attr]
data['attributes'] = data['attributes'][:, attr_idx]
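    # append the length-bucket labels computed above as the last attribute column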
if len_attr is not None:
data['attributes'] = torch.cat([data['attributes'], torch.from_numpy(len_labels[:, None])], 1)
# save data to avoid identical reloading
loaded_data[path] = data
return data
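
A minimal usage sketch (not from the repository): it assumes a params object carrying the fields read above (attributes, max_len, max_vocab); the path and attribute names below are hypothetical placeholders.

    from argparse import Namespace

    params = Namespace(
        attributes=['sentiment', 'length_5'],  # one regular attribute plus length buckets of size 5
        max_len=30,                            # longest sentence kept at preprocessing time
        max_vocab=60000,                       # prune the dictionary to the 60k most frequent words
    )
    data = load_binarized('data/train.pth', params)
    # with max_len=30 and bucket size 5, data['attr_values']['length_5'] is
    # ['0-5', '5-10', '10-15', '15-20', '20-25', '25-30']

Calling the function a second time with the same path returns the cached dictionary from loaded_data rather than re-reading the file.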