def collate_fn()

in contactopt/loader.py [0:0]


    def collate_fn(batch):
        """Combine a list of per-sample dicts into one batch dict.

        The key set is read from the first sample. Every key except the
        per-object geometry keys is passed to torch's default collator. The
        geometry keys are handled by hand:

        * ``obj_contact_gt`` and ``obj_normals_aug`` are padded into a single
          tensor with ``list_to_padded`` using a pad value of -1.
        * ``obj_verts_gt`` / ``obj_verts_aug`` are each combined with
          ``obj_faces`` into a ``Meshes`` object, stored under ``mesh_gt`` and
          ``mesh_aug``. The raw vertex and face lists are not placed in the
          output themselves.

        NOTE(review): presumably the geometry fields vary in size between
        samples, which is why they bypass the default collator -- confirm
        against the dataset's ``__getitem__``.

        :param batch: indexable collection of sample dicts; each sample must
            contain all five geometry keys listed below
        :return: dict holding the collated entries described above
        """
        first_sample = batch[0]
        # Fields that are collated manually further down
        manual_keys = ('obj_faces', 'obj_verts_gt', 'obj_contact_gt', 'obj_normals_aug', 'obj_verts_aug')

        collated = {}
        for name in first_sample.keys():
            if name in manual_keys:
                continue
            # Everything else goes through torch's stock collation
            collated[name] = torch.utils.data._utils.collate.default_collate([item[name] for item in batch])

        def gather(name):
            # Pull one field out of every sample, preserving batch order
            return [item[name] for item in batch]

        gt_verts = gather('obj_verts_gt')
        aug_verts = gather('obj_verts_aug')
        faces = gather('obj_faces')
        contact = gather('obj_contact_gt')
        aug_normals = gather('obj_normals_aug')

        pad = pytorch3d.structures.utils.list_to_padded
        collated['obj_contact_gt'] = pad(contact, pad_value=-1)
        collated['obj_normals_aug'] = pad(aug_normals, pad_value=-1)

        # Vertices and faces are wrapped in Meshes objects instead of padded
        # tensors; the original author noted this is slower than padding but
        # probably fast enough. Both meshes share the same face lists.
        collated['mesh_gt'] = Meshes(verts=gt_verts, faces=faces)
        collated['mesh_aug'] = Meshes(verts=aug_verts, faces=faces)

        return collated