def point_vox_moco_collator()

in datasets/collators/point_vox_moco_collator.py [0:0]


def point_vox_moco_collator(batch):
    batch_size = len(batch)
    
    point = [x["data"] for x in batch]
    point_moco = [x["data_moco"] for x in batch]
    vox = [x["vox"] for x in batch]
    vox_moco = [x["vox_moco"] for x in batch]
    # labels are repeated N+1 times but they are the same
    labels = [x["label"][0] for x in batch]
    labels = torch.LongTensor(labels).squeeze()

    # data valid is repeated N+1 times but they are the same
    data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch])

    points_moco = torch.stack([point_moco[i][0] for i in range(batch_size)])
    points = torch.stack([point[i][0] for i in range(batch_size)])
    
    vox_moco = collate_fn([vox_moco[i][0] for i in range(batch_size)])
    vox = collate_fn([vox[i][0] for i in range(batch_size)])

    output_batch = {
        "points": points,
        "points_moco": points_moco,
        "vox": vox,
        "vox_moco": vox_moco,
        "label": labels,
        "data_valid": data_valid,
    }

    return output_batch