# point_vox_moco_collator()
#
# Defined in datasets/collators/point_vox_moco_lidar_collator.py [0:0]


def _collate_voxels(samples):
    """Merge the first voxel dict of every sample into one batched dict.

    Args:
        samples: One entry per batch element. Only ``entry[0]`` is read and it
            must be a mapping with the keys ``"voxels"``, ``"voxel_coords"``
            and ``"voxel_num_points"`` holding NumPy arrays. Any other key in
            the mapping is ignored.

    Returns:
        A dict with the same three keys whose values are torch tensors built
        by concatenating the per-sample arrays along axis 0.
        ``"voxel_coords"`` gains a leading column holding the index of the
        sample inside the batch.

    Raises:
        KeyError: If one of the three keys is missing from a sample.
    """
    chunks = {"voxels": [], "voxel_coords": [], "voxel_num_points": []}
    for batch_idx, sample in enumerate(samples):
        view = sample[0]
        chunks["voxels"].append(view["voxels"])
        chunks["voxel_num_points"].append(view["voxel_num_points"])
        # Prepend one column filled with the batch index so every voxel can
        # be traced back to its sample after concatenation. The pad widths
        # require voxel_coords to be two-dimensional.
        chunks["voxel_coords"].append(
            np.pad(
                view["voxel_coords"],
                ((0, 0), (1, 0)),
                mode="constant",
                constant_values=batch_idx,
            )
        )
    # Concatenate once per key instead of re-copying a growing array for
    # every sample.
    return {
        key: torch.from_numpy(np.concatenate(parts, axis=0))
        for key, parts in chunks.items()
    }


def point_vox_moco_collator(batch):
    """Collate point-cloud and voxel samples (plus their MoCo views) into a batch.

    Args:
        batch: Sequence of sample dicts. Each dict provides ``"data"``,
            ``"data_moco"``, ``"vox"``, ``"vox_moco"``, ``"label"`` and
            ``"data_valid"``; only element ``[0]`` of each of these is used.
            ``"data"[0]`` / ``"data_moco"[0]`` must be tensors that
            ``torch.stack`` can stack; ``"vox"[0]`` / ``"vox_moco"[0]`` must
            be voxel dicts as described in ``_collate_voxels``.

    Returns:
        Dict with keys ``"points"``, ``"points_moco"`` (stacked tensors),
        ``"vox"``, ``"vox_moco"`` (batched voxel dicts), ``"label"``
        (``LongTensor``) and ``"data_valid"`` (``BoolTensor``).
    """
    # NOTE: squeeze() removes every size-1 dimension, so a single-sample batch
    # with scalar labels yields a 0-d tensor. Kept as-is because callers may
    # depend on the existing shape.
    labels = torch.LongTensor([x["label"][0] for x in batch]).squeeze()

    # data valid is repeated N+1 times but they are the same
    data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch])

    points = torch.stack([x["data"][0] for x in batch])
    points_moco = torch.stack([x["data_moco"][0] for x in batch])

    # Both voxel views go through the same helper so they cannot diverge.
    vox_data = _collate_voxels([x["vox"] for x in batch])
    vox_moco_data = _collate_voxels([x["vox_moco"] for x in batch])

    return {
        "points": points,
        "points_moco": points_moco,
        "vox": vox_data,
        "vox_moco": vox_moco_data,
        "label": labels,
        "data_valid": data_valid,
    }