def get_collate_fn()

in ttw/utils.py [0:0]


def get_collate_fn(cuda=True):
    """Return a collate function that merges a list of sample dicts into a batch.

    The returned callable takes ``data``, an indexable, non-empty sequence of
    per-sample mappings. The keys of the first sample decide which fields are
    considered; each recognised key is converted as follows:

    * ``'textrecog'``, ``'landmarks'``: ``list_to_tensor``; only the first
      returned value is kept.
    * ``'goldstandard'``, ``'actions'``, ``'utterance'``: ``list_to_tensor``;
      the first returned value is stored under the key and the second under
      ``key + '_mask'``.
    * ``'target'``: ``torch.LongTensor``.
    * ``'resnet'``, ``'weight'``: ``torch.FloatTensor``.
    * ``'fasttext'``: ``list_to_tensor`` with
      ``tensor_type=torch.FloatTensor``; only the first returned value is kept.

    Keys outside these groups are left out of the batch. The assembled dict is
    passed through ``to_variable`` together with the ``cuda`` flag, and that
    call's result is returned.

    NOTE(review): ``list_to_tensor`` presumably pads variable-length entries
    and returns ``(tensor, mask)``, and ``to_variable`` presumably moves the
    tensors to the GPU when ``cuda`` is true -- confirm against their
    definitions.
    """
    tensor_only_keys = ('textrecog', 'landmarks')
    tensor_and_mask_keys = ('goldstandard', 'actions', 'utterance')
    float_keys = ('resnet', 'weight')

    def _collate_fn(data):
        batch = {}
        # The first sample defines the set of fields that get collated.
        for key in data[0]:
            # Gather this field from every sample, in sample order.
            column = [sample[key] for sample in data]
            if key in tensor_only_keys:
                batch[key], _ = list_to_tensor(column)
            elif key in tensor_and_mask_keys:
                batch[key], batch[key + '_mask'] = list_to_tensor(column)
            elif key == 'target':
                batch[key] = torch.LongTensor(column)
            elif key in float_keys:
                batch[key] = torch.FloatTensor(column)
            elif key == 'fasttext':
                batch[key], _ = list_to_tensor(
                    column, tensor_type=torch.FloatTensor)
        return to_variable(batch, cuda=cuda)

    return _collate_fn