in contactopt/loader.py [0:0]
def collate_fn(batch):
    """Collate a list of per-sample dicts into a single batch dict.

    Every key except the object-mesh keys below is collated with PyTorch's
    default collator. The object-mesh keys are collated manually:

    * ``obj_contact_gt`` and ``obj_normals_aug`` are padded to a common length
      with ``-1`` via ``pytorch3d.structures.utils.list_to_padded``.
    * ``obj_verts_gt`` / ``obj_verts_aug`` together with ``obj_faces`` are
      wrapped into ``Meshes`` objects stored under ``'mesh_gt'`` and
      ``'mesh_aug'``. The raw ``obj_verts_gt``, ``obj_verts_aug`` and
      ``obj_faces`` keys are NOT copied into the output.

    NOTE(review): presumably these keys are handled manually because their
    first dimension (vertex count) differs between samples -- confirm against
    the dataset.

    Args:
        batch: Non-empty list of sample dicts. The keys of ``batch[0]`` decide
            which keys are default-collated. Every sample must also provide all
            five manually collated keys.

    Returns:
        dict: The collated batch. It holds the default-collated keys (in
        ``batch[0]`` key order), followed by ``'obj_contact_gt'``,
        ``'obj_normals_aug'``, ``'mesh_gt'`` and ``'mesh_aug'``.

    Raises:
        IndexError: If ``batch`` is empty.
        KeyError: If a sample is missing a key present in ``batch[0]`` or one
            of the manually collated keys.
    """
    if not batch:
        raise IndexError("collate_fn received an empty batch")

    # Keys collated manually below instead of by the default collator.
    manual_keys = {'obj_faces', 'obj_verts_gt', 'obj_contact_gt', 'obj_normals_aug', 'obj_verts_aug'}

    out = dict()
    for key in batch[0].keys():
        if key in manual_keys:
            continue
        # NOTE(review): this is a private torch path. On torch >= 1.11 the
        # public ``torch.utils.data.default_collate`` is equivalent.
        out[key] = torch.utils.data._utils.collate.default_collate([d[key] for d in batch])

    verts_gt_all = [sample['obj_verts_gt'] for sample in batch]
    verts_aug_all = [sample['obj_verts_aug'] for sample in batch]
    faces_all = [sample['obj_faces'] for sample in batch]
    contact_all = [sample['obj_contact_gt'] for sample in batch]
    obj_normals_aug_all = [sample['obj_normals_aug'] for sample in batch]

    # Pad with -1. Presumably downstream code treats -1 as "padding" -- confirm.
    out['obj_contact_gt'] = pytorch3d.structures.utils.list_to_padded(contact_all, pad_value=-1)
    out['obj_normals_aug'] = pytorch3d.structures.utils.list_to_padded(obj_normals_aug_all, pad_value=-1)

    # Meshes takes the unpadded per-sample lists directly. This was preferred
    # over padding verts/faces with list_to_padded; it is slower, but the
    # original author judged it fast enough. The GT and augmented meshes share
    # the same face lists.
    out['mesh_gt'] = Meshes(verts=verts_gt_all, faces=faces_all)
    out['mesh_aug'] = Meshes(verts=verts_aug_all, faces=faces_all)
    return out