def _train()

in tzrec/tools/tdm/gen_tree/tree_cluster.py [0:0]


    def _train(self, pipe: Connection, queue: mp.Queue) -> None:
        """Run one worker's job loop and send its code array through ``pipe``.

        Jobs are ``(parent_code, index)`` tuples pulled from ``queue``. A job
        whose ``index`` holds at most ``self.mini_batch`` entries is handed to
        ``self._mini_batch`` together with the worker-local ``code`` array.
        A larger job is split with ``self._cluster`` and every resulting
        sub-index with more than one entry is put back on ``queue`` under the
        code ``self.n_clusters * parent_code + i + 1``.

        ``self.timeout`` is updated after every clustering step as a weighted
        mix of its previous value and the latest clustering time, with a floor
        of 5.

        The loop ends once no job could be fetched, this worker has already
        processed at least one job, and either the last job was a mini-batch
        job or three consecutive empty rounds have been seen.

        Args:
            pipe: connection on which the final ``code`` array is sent.
            queue: shared job queue; read from and written to by this method.
        """
        last_size = -1
        catch_time = 0  # consecutive empty rounds since the last job
        processed = False
        code = np.zeros((len(self.leaf_nodes),), dtype=np.int64)
        parent_code = None
        index = None
        while True:
            # Up to three attempts to fetch a job before treating the round
            # as empty.
            for _ in range(3):
                try:
                    parent_code, index = queue.get(timeout=self.timeout)
                except Exception:
                    # Best effort: any failure to fetch (typically a timeout
                    # on an empty queue) counts as "no job".
                    index = None
                if index is not None:
                    break

            if index is None:
                # NOTE(review): a worker that never receives any job keeps
                # `processed` False and therefore never leaves this loop;
                # confirm the caller guarantees every worker gets work.
                if processed and (last_size <= self.mini_batch or catch_time >= 3):
                    logger.info("Process {} exits".format(os.getpid()))
                    break
                logger.info(
                    "Got empty job, pid: {}, time: {}".format(os.getpid(), catch_time)
                )
                catch_time += 1
                continue

            processed = True
            catch_time = 0
            last_size = len(index)
            if last_size <= self.mini_batch:
                self._mini_batch(parent_code, index, code)
                continue

            start = time.time()
            sub_index = self._cluster(index)
            # Measure once so the logged time and the timeout update agree.
            elapsed = time.time() - start
            logger.info(
                "Train iteration done, parent_code:{}, "
                "data size: {}, elapsed time: {}".format(
                    parent_code, len(index), elapsed
                )
            )
            # Adapt the queue timeout to the observed clustering time,
            # never dropping below 5.
            self.timeout = max(5, int(0.4 * self.timeout + 0.6 * elapsed))

            for i in range(self.n_clusters):
                # Sub-indices with at most one entry are not re-queued
                # (presumably nothing is left to split - TODO confirm).
                if len(sub_index[i]) > 1:
                    queue.put((self.n_clusters * parent_code + i + 1, sub_index[i]))

        # Vectorised count of the entries that hold a positive code.
        process_count = int(np.count_nonzero(code > 0))
        logger.info("Process {} process {} items".format(os.getpid(), process_count))
        pipe.send(code)