def __init__()

in fairnr/modules/encoder.py


    def __init__(self, args, voxel_path=None, bbox_path=None, shared_values=None):
        super().__init__(args)
        # read initial voxels or learned sparse voxels
        self.voxel_path = voxel_path if voxel_path is not None else args.voxel_path
        self.bbox_path = bbox_path if bbox_path is not None else getattr(args, "initial_boundingbox", None)
        assert (self.bbox_path is not None) or (self.voxel_path is not None), \
            "either an initial bounding box or a pretrained voxel file is required."
        self.voxel_index = None
        self.scene_scale = getattr(args, "scene_scale", 1.0)

        if self.voxel_path is not None:
            # read voxel file
            assert os.path.exists(self.voxel_path), "voxel file must exist"
            
            if Path(self.voxel_path).suffix == '.ply':
                from plyfile import PlyData, PlyElement
                plyvoxel = PlyData.read(self.voxel_path)
                elements = [x.name for x in plyvoxel.elements]
                
                assert 'vertex' in elements
                plydata = plyvoxel['vertex']
                fine_points = torch.from_numpy(
                    np.stack([plydata['x'], plydata['y'], plydata['z']]).astype('float32').T)

                if 'face' in elements:
                    # read voxel meshes... automatically detect voxel size
                    faces = plyvoxel['face']['vertex_indices']
                    t = fine_points[faces[0].astype('int64')]
                    voxel_size = torch.abs(t[0] - t[1]).max()

                    # deduplicate the voxel vertices (mesh faces share corners)
                    fine_points = torch.unique(fine_points, dim=0)

                    # vertex_ids, _ = discretize_points(fine_points, voxel_size)
                    # vertex_ids_offset = vertex_ids + 1
                    
                    # # simple hashing
                    # vertex_ids = vertex_ids[:, 0] * 1000000 + vertex_ids[:, 1] * 1000 + vertex_ids[:, 2]
                    # vertex_ids_offset = vertex_ids_offset[:, 0] * 1000000 + vertex_ids_offset[:, 1] * 1000 + vertex_ids_offset[:, 2]

                    # vertex_ids = {k: True for k in vertex_ids.tolist()}
                    # vertex_inside = [v in vertex_ids for v in vertex_ids_offset.tolist()]
                    
                    # # get voxel centers
                    # fine_points = fine_points[torch.tensor(vertex_inside)] + voxel_size * .5
                    # fine_points = fine_points + voxel_size * .5   --> use all corners as centers
                
                else:
                    # voxel size must be provided
                    assert getattr(args, "voxel_size", None) is not None, "a voxel size must be provided when the .ply file has no faces."
                    voxel_size = args.voxel_size

                if 'quality' in elements:
                    self.voxel_index = torch.from_numpy(plydata['quality']).long()
               
            else:
                # support the old-style .txt voxel point files
                fine_points = torch.from_numpy(np.loadtxt(self.voxel_path)[:, 3:].astype('float32'))
        else:
            # read bounding-box file
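            # bbox[:6] gives the box extents; the last entry (bbox[-1]) is a default voxel size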
            bbox = np.loadtxt(self.bbox_path)
            voxel_size = bbox[-1] if getattr(args, "voxel_size", None) is None else args.voxel_size
            fine_points = torch.from_numpy(bbox2voxels(bbox[:6], voxel_size))
        
        half_voxel = voxel_size * .5
        
        # transform from voxel centers to voxel corners (key/values)
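        # each voxel contributes 8 corner keys; torch.unique deduplicates corners shared by
        # neighbouring voxels, and the inverse indices, reshaped to [num_voxels, 8], map each
        # voxel back to the ids of its corners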
        fine_coords, _ = discretize_points(fine_points, half_voxel)
        fine_keys0 = offset_points(fine_coords, 1.0).reshape(-1, 3)
        fine_keys, fine_feats = torch.unique(fine_keys0, dim=0, sorted=True, return_inverse=True)
        fine_feats = fine_feats.reshape(-1, 8)
        num_keys = torch.scalar_tensor(fine_keys.size(0)).long()
        
        # ray-marching step size
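        # the step can be specified either as a ratio of the voxel size or as an absolute value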
        if getattr(args, "raymarching_stepsize_ratio", 0) > 0:
            step_size = args.raymarching_stepsize_ratio * voxel_size
        else:
            step_size = args.raymarching_stepsize
        
        # register buffers (saved to checkpoints, but not updated by the optimizer)
        self.register_buffer("points", fine_points)          # voxel centers
        self.register_buffer("keys", fine_keys.long())       # id used to find voxel corners/embeddings
        self.register_buffer("feats", fine_feats.long())     # for each voxel, 8 voxel corner ids
        self.register_buffer("num_keys", num_keys)
        self.register_buffer("keep", fine_feats.new_ones(fine_feats.size(0)).long())  # 1 if the voxel is kept, 0 once it has been pruned

        self.register_buffer("voxel_size", torch.scalar_tensor(voxel_size))
        self.register_buffer("step_size", torch.scalar_tensor(step_size))
        self.register_buffer("max_hits", torch.scalar_tensor(args.max_hits))

        logger.info("loaded {} voxel centers, {} voxel corners".format(fine_points.size(0), num_keys))

        # set up other hyperparameters and initialize runtime caches
        self.embed_dim = getattr(args, "voxel_embed_dim", None)
        self.deterministic_step = getattr(args, "deterministic_step", False)
        self.use_octree = getattr(args, "use_octree", False)
        self.track_max_probs = getattr(args, "track_max_probs", False)    
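        # caches start empty and are filled lazily at run time (e.g. the flattened octree
        # layout and per-voxel maximum probabilities)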
        self._runtime_caches = {
            "flatten_centers": None,
            "flatten_children": None,
            "max_voxel_probs": None
        }

        # sparse voxel embeddings
        if shared_values is None and self.embed_dim is not None and self.embed_dim > 0:
            self.values = Embedding(num_keys, self.embed_dim, None)
        else:
            self.values = shared_values
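
The center-to-corner bookkeeping above is compact, so a small, self-contained sketch of the same idea may help. It is only an illustration: expand_to_corners below is a hypothetical stand-in for fairnr's discretize_points / offset_points helpers, assuming corners sit on a half-voxel grid as the constructor's call pattern suggests.

import torch

def expand_to_corners(centers, voxel_size):
    # hypothetical stand-in: discretize centers on a half-voxel grid and offset by
    # +/-1 along every axis to reach the 8 corners of each voxel
    half = voxel_size * 0.5
    coords = torch.round(centers / half).long()                       # (N, 3) integer centers
    offsets = torch.tensor([[x, y, z] for x in (-1, 1)
                                      for y in (-1, 1)
                                      for z in (-1, 1)])              # (8, 3) corner offsets
    return (coords[:, None, :] + offsets[None, :, :]).reshape(-1, 3)  # (N * 8, 3) corner keys

# two unit voxels sharing a face
centers = torch.tensor([[0.5, 0.5, 0.5],
                        [1.5, 0.5, 0.5]])
corner_keys = expand_to_corners(centers, voxel_size=1.0)

# same pattern as the constructor: deduplicate shared corners, keep per-voxel corner ids
keys, feats = torch.unique(corner_keys, dim=0, sorted=True, return_inverse=True)
feats = feats.reshape(-1, 8)
print(keys.size(0))   # 12 unique corners (4 are shared between the two voxels)
print(feats.shape)    # torch.Size([2, 8]), analogous to `self.feats`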