in E2E_TOD/utils.py [0:0]
def padSeqs(sequences, maxlen=None, truncated=False, pad_method='post',
            trunc_method='pre', dtype='int32', value=0.):
    """Pad (and optionally truncate) variable-length sequences into one array.

    Args:
        sequences: A sized collection (anything with ``__len__``) whose items
            are themselves sized sequences. Items may be nested; the trailing
            shape is taken from the first non-empty item and every other
            non-empty item must match it.
        maxlen: Upper bound on the padded length. It is only honoured when
            ``truncated`` is true; otherwise the longest sequence length is
            used and this argument is ignored.
        truncated: When true and ``maxlen`` is given, the output length is
            ``min(longest sequence, maxlen)`` and longer sequences are cut.
        pad_method: ``'post'`` writes each sequence at the start of its row
            (padding after it); ``'pre'`` writes it at the end of its row.
        trunc_method: ``'pre'`` keeps the last ``maxlen`` items of a too-long
            sequence; ``'post'`` keeps the first ``maxlen`` items.
        dtype: dtype of the returned array.
        value: Fill value used for the padding positions.

    Returns:
        A NumPy array of shape ``(len(sequences), length) + sample_shape``.
        An empty ``sequences`` yields an empty ``(0, 0)`` array.

    Raises:
        ValueError: If ``sequences`` or one of its items has no ``__len__``,
            if ``trunc_method`` / ``pad_method`` is not recognised, or if a
            sequence's trailing shape differs from the expected one.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    lengths = []
    for item in sequences:
        if not hasattr(item, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(item))
        lengths.append(len(item))

    num_samples = len(sequences)
    # An empty batch has no maximum; fall back to 0 instead of letting
    # np.max([]) fail with an unrelated-looking reduction error.
    seq_maxlen = max(lengths) if lengths else 0

    if maxlen is not None and truncated:
        maxlen = min(seq_maxlen, maxlen)
    else:
        maxlen = seq_maxlen

    # The trailing (per-step) shape comes from the first non-empty sequence;
    # consistency of the remaining sequences is checked in the loop below.
    sample_shape = tuple()
    for seq in sequences:
        if len(seq) > 0:
            sample_shape = np.asarray(seq).shape[1:]
            break

    padded = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
    for idx, seq in enumerate(sequences):
        if not len(seq):
            print('empty list/array was found')
            continue
        if trunc_method == 'pre':
            # Use an explicit start index: `seq[-maxlen:]` would return the
            # whole sequence when maxlen == 0, because -0 == 0.
            trunc = seq[max(len(seq) - maxlen, 0):]
        elif trunc_method == 'post':
            trunc = seq[:maxlen]
        else:
            raise ValueError('Truncating type "%s" not understood' % trunc_method)

        if not len(trunc):
            # Only reachable when maxlen == 0: the row has zero width, so
            # there is nothing to copy (and `-0:` below would misbehave).
            continue

        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if pad_method == 'post':
            padded[idx, :len(trunc)] = trunc
        elif pad_method == 'pre':
            padded[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % pad_method)
    return padded