in tzrec/tools/tdm/gen_tree/tree_cluster.py [0:0]
def train(self, save_tree: bool = False) -> TDMTreeNode:
    """Cluster the leaf nodes with worker processes and build the tree.

    Calls ``self._read()``, seeds a shared work queue with a single task that
    covers every leaf node, starts ``self.parallel`` worker processes running
    ``self._train``, copies every positive code reported by a worker onto the
    matching leaf node, and finally builds the tree with
    ``tree_builder.TreeBuilder``.

    Args:
        save_tree: forwarded unchanged to ``TreeBuilder.build``.

    Returns:
        The root node returned by ``TreeBuilder.build``.

    Raises:
        EOFError: if a worker exits without sending its codes.
        AssertionError: if the work queue is not empty after all workers
            have been joined.
    """
    self._read()
    num_leaves = len(self.leaf_nodes)

    # Each (code, index) entry in the queue holds the node number of a
    # cluster in the tree and the indices of the items that belong to
    # that cluster, respectively.
    queue = mp.Queue()
    queue.put((0, np.arange(num_leaves)))

    processes = []
    pipes = []
    for _ in range(self.parallel):
        parent_conn, child_conn = mp.Pipe()
        p = mp.Process(target=self._train, args=(child_conn, queue))
        processes.append(p)
        pipes.append(parent_conn)
        p.start()
        # The worker owns its own handle to the sending end once started.
        # Dropping the parent's copy lets recv() below raise EOFError instead
        # of blocking forever if the worker dies before sending a result.
        child_conn.close()

    # Allocated for the instance; this method does not fill it in.
    self.codes = np.zeros((num_leaves,), dtype=np.int64)
    for pipe in pipes:
        try:
            codes = pipe.recv()
        finally:
            # Release the parent's end whether or not the receive succeeded.
            pipe.close()
        # Only positive entries are applied; entries <= 0 leave the leaf
        # node's tree_code untouched.
        for i, code in enumerate(codes):
            if code > 0:
                self.leaf_nodes[i].tree_code = code

    for p in processes:
        p.join()

    # Every queued task should have been consumed by the workers.
    assert queue.empty()
    builder = tree_builder.TreeBuilder(self.output_dir, self.n_clusters)
    root = builder.build(self.leaf_nodes, save_tree)
    return root