in utils.py [0:0]
def batch_dialogs(dials):
    """
    Batch dialogs into a single zero-padded token-id tensor.

    :param dials: sequence of dialogs; each dialog is a sequence of sentences
        and each sentence is a sequence of integer token ids.
    :return: tuple ``(mat, dial_length, wlen)`` where

        - ``mat`` is a long tensor of shape
          ``(len(dials), max_sentences, max_tokens)``; unused positions are 0.
        - ``dial_length`` is a numpy integer array with the number of
          sentences in each dialog.
        - ``wlen`` is a float tensor of shape ``(len(dials), max_sentences)``
          with each sentence's token count (0 for padding slots).
    """
    dial_length = [len(dial) for dial in dials]
    word_length = [[len(sent) for sent in dial] for dial in dials]
    # default=0 keeps an empty batch (or a batch containing only empty
    # dialogs / empty sentences) from crashing max() on an empty sequence.
    max_sents = max(dial_length, default=0)
    max_words = max((n for lens in word_length for n in lens), default=0)

    mat = torch.zeros((len(dials), max_sents, max_words), dtype=torch.long)
    wlen = torch.zeros((len(dials), max_sents))
    for i, dial in enumerate(dials):
        for j, sent in enumerate(dial):
            n_tok = word_length[i][j]
            mat[i, j, :n_tok] = torch.as_tensor(sent, dtype=torch.long)
            wlen[i, j] = n_tok
    # Explicit integer dtype so an empty batch still yields an int array.
    return mat, np.array(dial_length, dtype=int), wlen