in dataset/co3d_dataset.py [0:0]
def collate(cls, batch):
    """
    Collate a list `batch` of `cls` instances into a single batched `cls`,
    suitable for feeding to deep networks.

    Frames that share the same point-cloud object are deduplicated: the
    batched result stores each distinct point cloud once, plus a per-frame
    index (`sequence_point_cloud_idx`) mapping every batch element to its
    cloud. The collation recurses field-by-field via `cls.collate`.
    """
    first = batch[0]

    if isinstance(first, cls):
        # Group batch positions by the identity of their point-cloud object,
        # so shared clouds are stored only once in the collated output.
        groups = defaultdict(list)
        for pos, entry in enumerate(batch):
            groups[id(entry.sequence_point_cloud)].append(pos)

        unique_clouds = []
        cloud_idx = -np.ones((len(batch),))
        for group_no, positions in enumerate(groups.values()):
            # numpy fancy assignment: every frame in this group points at
            # the same deduplicated cloud.
            cloud_idx[positions] = group_no
            unique_clouds.append(batch[positions[0]].sequence_point_cloud)
        # every frame must have been assigned a cloud index
        assert (cloud_idx >= 0).all()

        # The pre-collate value of sequence_point_cloud_idx is intentionally
        # discarded and replaced by the deduplicated mapping computed above.
        overrides = {
            "sequence_point_cloud": unique_clouds,
            "sequence_point_cloud_idx": cloud_idx.tolist(),
        }

        collated = {}
        for f in fields(first):
            values = overrides.get(f.name, [getattr(d, f.name) for d in batch])
            # A field collates only when present on every element; otherwise
            # the whole field is dropped to None in the batched instance.
            if all(v is not None for v in values):
                collated[f.name] = cls.collate(values)
            else:
                collated[f.name] = None
        return cls(**collated)

    if isinstance(first, Pointclouds):
        # TODO: use concatenation
        # Re-wrap each cloud's padded tensors into one Pointclouds object.
        return type(first)(
            points=[p.points_padded()[0] for p in batch],
            normals=[p.normals_padded()[0] for p in batch],
            features=[p.features_padded()[0] for p in batch],
        )

    if isinstance(first, CamerasBase):
        # TODO: make a function for it
        # TODO: don't store K; enforce working in NDC space
        intrinsics = None
        if first.K is not None:
            intrinsics = torch.cat([c.K for c in batch], dim=0)
        return type(first)(
            R=torch.cat([c.R for c in batch], dim=0),
            T=torch.cat([c.T for c in batch], dim=0),
            K=intrinsics,
            focal_length=torch.cat([c.focal_length for c in batch], dim=0),
            principal_point=torch.cat([c.principal_point for c in batch], dim=0),
        )

    # Anything else falls through to the default torch collation.
    return torch.utils.data._utils.collate.default_collate(batch)