in ttw/utils.py [0:0]
def get_collate_fn(cuda=True):
    """Return a collate function that merges a list of per-sample dicts into one batch dict.

    Args:
        cuda: forwarded unchanged to ``to_variable`` when the batch is finalized.

    Returns:
        A callable ``_collate_fn(data)``. It takes a list of sample dicts and
        returns ``to_variable(batch, cuda=cuda)``.
    """

    def _collate_fn(data):
        """Gather each key of the first sample across ``data`` and convert it by key name.

        Conversion rules:

        * ``textrecog``, ``landmarks``: ``list_to_tensor``. The second returned
          value (presumably a mask; confirm) is discarded.
        * ``goldstandard``, ``actions``, ``utterance``: ``list_to_tensor``. Both
          returned values are kept, as ``<key>`` and ``<key>_mask``.
        * ``target``: ``torch.LongTensor``.
        * ``resnet``, ``weight``: ``torch.FloatTensor``.
        * ``fasttext``: ``list_to_tensor`` with ``tensor_type=torch.FloatTensor``.
          The second returned value is discarded.
        * Any other key is gathered but not copied into the batch.

        The key set is taken from ``data[0]``. An empty ``data`` raises
        ``IndexError``. A later sample missing one of those keys raises
        ``KeyError``.
        """
        collated = {}
        first_sample = data[0]
        for key in first_sample.keys():
            # Built for every key, including ones dropped below, so that a
            # sample missing a key fails loudly regardless of the key's type.
            column = [sample[key] for sample in data]

            if key in ('textrecog', 'landmarks'):
                collated[key], _ = list_to_tensor(column)
            elif key in ('goldstandard', 'actions', 'utterance'):
                # Keep the companion value under "<key>_mask"
                # (e.g. "utterance" -> "utterance_mask").
                collated[key], collated[key + '_mask'] = list_to_tensor(column)
            elif key == 'target':
                collated[key] = torch.LongTensor(column)
            elif key in ('resnet', 'weight'):
                collated[key] = torch.FloatTensor(column)
            elif key == 'fasttext':
                collated[key], _ = list_to_tensor(column, tensor_type=torch.FloatTensor)

        # NOTE(review): to_variable presumably wraps the tensors and moves them
        # to the GPU when cuda=True. Confirm against its definition.
        return to_variable(collated, cuda=cuda)

    return _collate_fn