in fairnr/modules/encoder.py [0:0]
def export_voxels(self, return_mesh=False):
    """Export the currently kept (unpruned) voxels as a ``PlyData`` object.

    Args:
        return_mesh (bool): when False (default), return a point cloud with one
            vertex per kept voxel, located at the corresponding row of
            ``self.points``. When True, return a quad mesh built from the voxel
            cubes: one vertex per distinct integer corner and one 4-index face
            per distinct integer face centre.

    Returns:
        PlyData: a ``vertex`` element (``x``, ``y``, ``z`` and, for the point
        cloud, ``quality``), plus a ``face`` element (``vertex_indices``, four
        ``i4`` per face) when ``return_mesh`` is True.
    """
    logger.info("exporting learned sparse voxels...")
    keep = self.keep.bool()
    # Position of every kept voxel in the full (unpruned) voxel list.
    voxel_idx = torch.arange(self.keep.size(0), device=self.keep.device)[keep]
    voxel_pts = self.points[keep]

    if not return_mesh:
        # HACK: we export the original voxel indices as "quality" in case for editing
        # NOTE(review): 'f4' holds integers exactly only up to 2**24, so indices of
        # very large voxel grids would be rounded here.
        # Copy to host once with tolist(); indexing (possibly GPU) tensors element
        # by element from Python costs one device round-trip per value.
        coords = voxel_pts[:, :3].detach().cpu().tolist()
        indices = voxel_idx.cpu().tolist()
        points = [(x, y, z, q) for (x, y, z), q in zip(coords, indices)]
        vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('quality', 'f4')])
        return PlyData([PlyElement.describe(vertex, 'vertex')])

    vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    face_dtype = [('vertex_indices', 'i4', (4,))]
    if voxel_pts.size(0) == 0:
        # Nothing is kept: return an empty mesh rather than handing an empty
        # tensor to discretize_points.
        return PlyData([
            PlyElement.describe(np.array([], dtype=vertex_dtype), 'vertex'),
            PlyElement.describe(np.array([], dtype=face_dtype), 'face'),
        ])

    # generate polygon for voxels
    center_coords, residual = discretize_points(voxel_pts, self.voxel_size / 2)
    # The 8 cube corners as +-1 steps from each integer voxel centre.
    offsets = torch.tensor(
        [[-1, -1, -1], [-1, -1, 1], [-1, 1, -1], [1, -1, -1],
         [1, 1, -1], [1, -1, 1], [-1, 1, 1], [1, 1, 1]],
        device=center_coords.device)
    vertex_coords = center_coords[:, None, :] + offsets[None, :, :]  # (N, 8, 3) integer corners
    # Integer corner coordinates are scaled by voxel_size / 2 and shifted by the
    # residual returned from discretize_points to obtain world positions.
    vertex_points = vertex_coords.type_as(residual) * self.voxel_size / 2 + residual
    # Each row lists 4 corner indices (into `offsets`) forming one cube face.
    faceidxs = [[1, 6, 7, 5], [7, 6, 2, 4], [5, 7, 4, 3], [1, 0, 2, 6], [1, 5, 3, 0], [0, 3, 4, 2]]

    # Move everything to host once; the dedup loops below then work on plain
    # Python ints/floats instead of per-element tensor indexing.
    corners = [[tuple(c) for c in voxel] for voxel in vertex_coords.long().cpu().tolist()]
    corner_pts = vertex_points.detach().cpu().tolist()

    # Deduplicate corners by integer coordinate, keeping first-seen order
    # (voxel-major, then corner order of `offsets`).
    vertex_index, vertex_rows = {}, []
    for voxel_corners, voxel_corner_pts in zip(corners, corner_pts):
        for key, pt in zip(voxel_corners, voxel_corner_pts):
            if key not in vertex_index:
                vertex_index[key] = len(vertex_rows)
                vertex_rows.append(tuple(pt))

    # Deduplicate faces by integer face centre, keeping first-seen order
    # (voxel-major, then face order of `faceidxs`). A face shared by two
    # neighbouring voxels (centres 2 units apart) is therefore emitted once.
    face_index = {}
    for voxel_corners in corners:
        for quad in faceidxs:
            face = [voxel_corners[k] for k in quad]
            # With the offsets above, the 4 corners of every face sum to exactly
            # 4x the face centre, so this floor division is exact.
            center = tuple(sum(axis) // 4 for axis in zip(*face))
            if center not in face_index:
                face_index[center] = [vertex_index[c] for c in face]

    vertex = np.array(vertex_rows, dtype=vertex_dtype)
    face = np.array([(idx,) for idx in face_index.values()], dtype=face_dtype)
    return PlyData([PlyElement.describe(vertex, 'vertex'), PlyElement.describe(face, 'face')])