def train()

in tzrec/tools/tdm/gen_tree/tree_cluster.py [0:0]


    def train(self, save_tree: bool = False) -> TDMTreeNode:
        """Cluster the leaf nodes with worker processes and build the tree.

        Loads the data through ``self._read()``, starts ``self.parallel``
        processes running ``self._train``, copies every positive code the
        workers send back onto the matching entry of ``self.leaf_nodes`` and
        finally hands the leaves to ``tree_builder.TreeBuilder``.

        Args:
            save_tree (bool): forwarded unchanged to ``TreeBuilder.build``.

        Returns:
            TDMTreeNode: whatever ``TreeBuilder.build`` returns.

        Raises:
            EOFError: if a worker closes its pipe without sending its codes.
            AssertionError: if the work queue is not empty once all workers
                have been joined.
        """
        self._read()
        # Each queue entry is a (code, index) pair: the node number in the
        # tree being clustered and the indices of the items that belong to
        # that node. Seed it with code 0 covering every leaf index.
        queue = mp.Queue()
        queue.put((0, np.arange(len(self.leaf_nodes))))
        processes = []
        pipes = []
        for _ in range(self.parallel):
            parent_conn, child_conn = mp.Pipe()
            p = mp.Process(target=self._train, args=(child_conn, queue))
            processes.append(p)
            pipes.append(parent_conn)
            p.start()
            # The worker holds its own handle to this end after start().
            # Dropping the parent's copy releases the descriptor and lets
            # recv() raise EOFError instead of blocking forever if the
            # worker exits without sending anything.
            child_conn.close()

        self.codes = np.zeros((len(self.leaf_nodes),), dtype=np.int64)
        # Receive before joining: a worker blocked while sending a large
        # payload cannot exit until the parent has read it.
        for pipe in pipes:
            try:
                codes = pipe.recv()
            finally:
                pipe.close()
            # Only positive entries are applied; presumably each worker
            # leaves the items it did not handle at a non-positive value
            # (verify against _train).
            for i, code in enumerate(codes):
                if code > 0:
                    self.leaf_nodes[i].tree_code = code

        for p in processes:
            p.join()

        # NOTE: this sanity check is skipped when Python runs with -O.
        assert queue.empty()
        builder = tree_builder.TreeBuilder(self.output_dir, self.n_clusters)
        root = builder.build(self.leaf_nodes, save_tree)
        return root