in shapenet/data/mesh_vox.py [0:0]
def postprocess(self, batch, device=None):
    """Move a collated batch onto ``device`` and derive missing targets.

    Args:
        batch: 7-tuple ``(imgs, meshes, points, normals, voxels, Ps, id_strs)``.
            Presumably the output of this dataset's collate function; any of
            ``meshes``, ``points``, ``normals`` and ``voxels`` may be ``None``.
        device: Target device. Defaults to ``torch.device("cuda")``.

    Behaviour per field:
        * ``imgs`` is always moved to ``device``.
        * ``meshes`` is moved to ``device`` when present.
        * ``points`` / ``normals``: if both are given they are moved to
          ``device``. Otherwise they are sampled from ``meshes`` with
          ``self.num_samples`` samples. If there is no mesh either, both are
          returned as ``None``.
        * ``voxels``: ``None`` is passed through. A tensor (cached on disk) is
          moved to ``device``. Anything else is treated as a per-sample
          iterable of voxel-coordinate tensors. Each one is moved to
          ``device``, voxelized with ``self._voxelize`` using the matching
          ``Ps[i]``, and the results are stacked along a new dim 0.

    Returns:
        ``(imgs, meshes, points, normals, voxels)``, with ``id_strs`` appended
        as a sixth element when ``self.return_id_str`` is truthy.
    """
    if device is None:
        device = torch.device("cuda")
    imgs, meshes, points, normals, voxels, Ps, id_strs = batch
    imgs = imgs.to(device)
    if meshes is not None:
        meshes = meshes.to(device)

    if points is not None and normals is not None:
        points = points.to(device)
        normals = normals.to(device)
    elif meshes is not None:
        points, normals = sample_points_from_meshes(
            meshes, num_samples=self.num_samples, return_normals=True
        )
    else:
        # Nothing to sample from: pass None through (as is done for meshes and
        # voxels) instead of crashing inside sample_points_from_meshes(None).
        points, normals = None, None

    if voxels is not None:
        if torch.is_tensor(voxels):
            # We used cached voxels on disk, just cast and return
            voxels = voxels.to(device)
        else:
            # We got a list of voxel_coords, and need to compute voxels on-the-fly.
            # Ps[i] is looked up by index so a length mismatch still raises.
            Ps = Ps.to(device)
            voxels = torch.stack(
                [
                    self._voxelize(cur_voxel_coords.to(device), Ps[i])
                    for i, cur_voxel_coords in enumerate(voxels)
                ],
                dim=0,
            )

    if self.return_id_str:
        return imgs, meshes, points, normals, voxels, id_strs
    return imgs, meshes, points, normals, voxels