def collate()

in dataset/co3d_dataset.py [0:0]


    def collate(cls, batch):
        """
        Given a list objects `batch` of class `cls`, collates them into a batched
        representation suitable for processing with deep networks.

        The behaviour depends on the type of the first element of `batch`:

        - instance of `cls`: every init field is collated recursively through
          `cls.collate` and a new `cls` is built from the results. A field is
          set to None as soon as any element of the batch has it set to None.
          `sequence_point_cloud` is de-duplicated by object identity, and
          `sequence_point_cloud_idx` maps each batch element to the position
          of its cloud in the de-duplicated list.
        - `Pointclouds`: the first cloud of every element is gathered into a
          new object of the same type. Normals / features are passed on only
          if every element defines them.
        - `CamerasBase`: R, T, K (if present on the first element),
          focal_length and principal_point are concatenated along dim 0.
        - anything else: delegated to torch's `default_collate`.

        Args:
            batch: non-empty list of objects of the same type.

        Returns:
            The collated object; its type depends on the branch taken above.

        Raises:
            ValueError: if a batch of `Pointclouds` mixes elements with and
                without normals (or features).
        """

        elem = batch[0]

        if isinstance(elem, cls):
            # Group batch positions by the identity of their point cloud so
            # that a cloud shared by several frames is collated only once.
            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
            id_to_idx = defaultdict(list)
            for i, pc_id in enumerate(pointcloud_ids):
                id_to_idx[pc_id].append(i)

            sequence_point_cloud = []
            # -1 marks "not assigned yet"; checked by the assert below.
            sequence_point_cloud_idx = -np.ones((len(batch),))
            for i, ind in enumerate(id_to_idx.values()):
                sequence_point_cloud_idx[ind] = i
                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
            assert (sequence_point_cloud_idx >= 0).all()

            override_fields = {
                "sequence_point_cloud": sequence_point_cloud,
                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
            }
            # note that the pre-collate value of sequence_point_cloud_idx is unused

            collated = {}
            for f in fields(elem):
                # Fields excluded from __init__ cannot be forwarded to
                # `cls(**collated)`; the constructor would reject them.
                if not getattr(f, "init", True):
                    continue

                list_values = override_fields.get(
                    f.name, [getattr(d, f.name) for d in batch]
                )
                collated[f.name] = (
                    cls.collate(list_values)
                    if all(value is not None for value in list_values)
                    else None
                )
            return cls(**collated)
        elif isinstance(elem, Pointclouds):
            # TODO: use concatenation
            # NOTE: `[0]` keeps only the first cloud of every batch element.
            optional_kwargs = {}
            for name in ("normals", "features"):
                # PyTorch3D returns None from `*_padded()` when the clouds were
                # built without that attribute, so it must not be indexed blindly.
                padded = [getattr(p, name + "_padded")() for p in batch]
                is_missing = [tensor is None for tensor in padded]
                if all(is_missing):
                    optional_kwargs[name] = None
                elif any(is_missing):
                    raise ValueError(
                        f"Pointclouds in the batch have some fields '{name}'"
                        " defined and some set to None."
                    )
                else:
                    optional_kwargs[name] = [tensor[0] for tensor in padded]

            pointclouds = type(elem)(
                points=[p.points_padded()[0] for p in batch],
                **optional_kwargs,
            )
            return pointclouds
        elif isinstance(elem, CamerasBase):
            # TODO: make a function for it
            # TODO: don't store K; enforce working in NDC space
            return type(elem)(
                R=torch.cat([c.R for c in batch], dim=0),
                T=torch.cat([c.T for c in batch], dim=0),
                # Presence of K is decided from the first element only.
                K=torch.cat([c.K for c in batch], dim=0)
                if elem.K is not None
                else None,
                focal_length=torch.cat([c.focal_length for c in batch], dim=0),
                principal_point=torch.cat([c.principal_point for c in batch], dim=0),
            )
        else:
            return torch.utils.data._utils.collate.default_collate(batch)