in fairnr/modules/encoder.py [0:0]
def __init__(self, args, voxel_path=None, bbox_path=None, shared_values=None):
    """Build the sparse voxel grid from a voxel file or an initial bounding box.

    Voxel centers are read from ``voxel_path`` when available; otherwise the
    bounding box stored in ``bbox_path`` is voxelized with ``bbox2voxels``.
    The voxel centers, the ids of their 8 corners and the ray-marching
    hyper-parameters are registered as buffers so they are saved with
    checkpoints.

    Args:
        args: model arguments. Reads ``voxel_path``, ``initial_boundingbox``,
            ``scene_scale``, ``voxel_size``, ``raymarching_stepsize_ratio``,
            ``raymarching_stepsize``, ``max_hits``, ``voxel_embed_dim``,
            ``deterministic_step``, ``use_octree`` and ``track_max_probs``.
        voxel_path: overrides ``args.voxel_path``. Either a ``.ply`` file
            (vertices used as voxel centers; if it also has a ``face``
            element the voxel size is inferred from the first face) or a
            legacy whitespace-separated text file whose columns from index 3
            onward hold the points.
        bbox_path: overrides ``args.initial_boundingbox``. A text file whose
            first 6 values are passed to ``bbox2voxels`` and whose last value
            is the voxel size used when ``args.voxel_size`` is not set.
        shared_values: an existing embedding module to reuse. When it is
            given, it is stored as ``self.values`` and no new ``Embedding``
            is created.

    Raises:
        AssertionError: if neither a voxel file nor a bounding box is given,
            the voxel file does not exist, the ``.ply`` file has no
            ``vertex`` element, or a point-cloud ``.ply`` is loaded without
            ``args.voxel_size``.
    """
    super().__init__(args)
    # read initial voxels or learned sparse voxels
    # (getattr mirrors the bounding-box lookup so a missing field is tolerated)
    self.voxel_path = voxel_path if voxel_path is not None else getattr(args, "voxel_path", None)
    self.bbox_path = bbox_path if bbox_path is not None else getattr(args, "initial_boundingbox", None)
    assert (self.bbox_path is not None) or (self.voxel_path is not None), \
        "at least initial bounding box or pretrained voxel files are required."
    self.voxel_index = None
    self.scene_scale = getattr(args, "scene_scale", 1.0)
    if self.voxel_path is not None:
        # read voxel file
        assert os.path.exists(self.voxel_path), "voxel file must exist"
        if Path(self.voxel_path).suffix == '.ply':
            from plyfile import PlyData
            plyvoxel = PlyData.read(self.voxel_path)
            elements = [x.name for x in plyvoxel.elements]
            assert 'vertex' in elements
            plydata = plyvoxel['vertex']
            fine_points = torch.from_numpy(
                np.stack([plydata['x'], plydata['y'], plydata['z']]).astype('float32').T)
            if 'face' in elements:
                # read voxel meshes... automatically detect voxel size as the
                # largest per-axis gap between the first two vertices of the
                # first face
                faces = plyvoxel['face']['vertex_indices']
                t = fine_points[faces[0].astype('int64')]
                # .item() yields a plain float, like the other branches produce
                voxel_size = torch.abs(t[0] - t[1]).max().item()
                # every unique mesh vertex (voxel corner) is used as a voxel center
                fine_points = torch.unique(fine_points, dim=0)
            else:
                # voxel size must be provided
                assert getattr(args, "voxel_size", None) is not None, "final voxel size is essential."
                voxel_size = args.voxel_size
                # 'quality' is a per-vertex *property*, not a PLY element, so
                # it has to be looked up among the vertex properties.
                vertex_props = [prop.name for prop in plydata.properties]
                if 'quality' in vertex_props:
                    # astype() gives a native-endian, contiguous int64 array
                    # (truncating like .long()), as torch.from_numpy requires.
                    self.voxel_index = torch.from_numpy(
                        np.asarray(plydata['quality']).astype('int64'))
        else:
            # supporting the old style .txt voxel points; ndmin=2 keeps a
            # single-row file 2-D so the column slice below still works
            fine_points = torch.from_numpy(
                np.loadtxt(self.voxel_path, ndmin=2)[:, 3:].astype('float32'))
    else:
        # read bounding-box file; its last value is the fallback voxel size
        bbox = np.loadtxt(self.bbox_path)
        voxel_size = bbox[-1] if getattr(args, "voxel_size", None) is None else args.voxel_size
        fine_points = torch.from_numpy(bbox2voxels(bbox[:6], voxel_size))
    half_voxel = voxel_size * .5
    # transform from voxel centers to voxel corners (key/values):
    # each voxel yields 8 corner keys; unique() assigns every distinct corner
    # an id, and the inverse map gives each voxel its 8 corner ids
    fine_coords, _ = discretize_points(fine_points, half_voxel)
    fine_keys0 = offset_points(fine_coords, 1.0).reshape(-1, 3)
    fine_keys, fine_feats = torch.unique(fine_keys0, dim=0, sorted=True, return_inverse=True)
    fine_feats = fine_feats.reshape(-1, 8)
    num_keys = torch.scalar_tensor(fine_keys.size(0)).long()
    # ray-marching step size: relative to the voxel size when a positive
    # ratio is configured, otherwise the absolute ``raymarching_stepsize``
    stepsize_ratio = getattr(args, "raymarching_stepsize_ratio", 0) or 0
    if stepsize_ratio > 0:
        step_size = stepsize_ratio * voxel_size
    else:
        step_size = args.raymarching_stepsize
    # register parameters (will be saved to checkpoints)
    self.register_buffer("points", fine_points)  # voxel centers
    self.register_buffer("keys", fine_keys.long())  # id used to find voxel corners/embeddings
    self.register_buffer("feats", fine_feats.long())  # for each voxel, 8 voxel corner ids
    self.register_buffer("num_keys", num_keys)
    self.register_buffer("keep", fine_feats.new_ones(fine_feats.size(0)).long())  # whether the voxel will be pruned
    self.register_buffer("voxel_size", torch.scalar_tensor(voxel_size))
    self.register_buffer("step_size", torch.scalar_tensor(step_size))
    self.register_buffer("max_hits", torch.scalar_tensor(args.max_hits))
    logger.info("loaded {} voxel centers, {} voxel corners".format(fine_points.size(0), num_keys.item()))
    # set-up other hyperparameters and initialize running time caches
    self.embed_dim = getattr(args, "voxel_embed_dim", None)
    self.deterministic_step = getattr(args, "deterministic_step", False)
    self.use_octree = getattr(args, "use_octree", False)
    self.track_max_probs = getattr(args, "track_max_probs", False)
    self._runtime_caches = {
        "flatten_centers": None,
        "flatten_children": None,
        "max_voxel_probs": None
    }
    # sparse voxel embeddings; an unset ``voxel_embed_dim`` (None) is treated
    # like a non-positive dim instead of crashing on ``None > 0``
    if shared_values is None and self.embed_dim is not None and self.embed_dim > 0:
        self.values = Embedding(num_keys, self.embed_dim, None)
    else:
        self.values = shared_values