def _cluster()

in tzrec/tools/tdm/gen_tree/tree_cluster.py [0:0]


    def _cluster(self, index: npt.NDArray) -> List[npt.NDArray]:
        """Split ``index`` into ``self.n_clusters`` size-balanced groups.

        KMeans (``random_state=0``) is fit on ``self.data[index]``.

        Every cluster is then capped at ``len(index) // self.n_clusters``
        members, keeping the points closest to its center. The trimmed points
        are handed out, in cluster order, to clusters below the cap. Each such
        cluster takes the remaining points nearest to its own center. Points
        still unassigned after that pass are appended to the first group.

        Args:
            index: row indices into ``self.data`` of the points to split
                (presumably a 1-D integer array -- confirm against callers).

        Returns:
            A list of ``self.n_clusters`` arrays whose entries are values
            taken from ``index``.
        """
        data = self.data[index]
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=0).fit(data)
        labels = kmeans.labels_
        # Distance of every point to every center, computed once and reused by
        # both passes instead of re-running ``transform`` per cluster.
        center_dist = kmeans.transform(data)
        ave_num = len(index) // self.n_clusters

        sub_indices = []
        # Positions into ``index``/``data`` of points trimmed from oversized
        # clusters; kept as positions so ``center_dist`` can be indexed.
        remain_pos = []
        for i in range(self.n_clusters):
            members = np.where(labels == i)[0]
            if len(members) <= ave_num:
                sub_indices.append(index[members])
            else:
                by_dist = members[np.argsort(center_dist[members, i])]
                sub_indices.append(index[by_dist[:ave_num]])
                remain_pos.extend(by_dist[ave_num:])
        remain_pos = np.asarray(remain_pos, dtype=np.int64)

        # Rebalance: top up each under-filled cluster with the trimmed points
        # nearest to its center.
        for i in range(self.n_clusters):
            if len(remain_pos) == 0:
                break
            deficit = ave_num - len(sub_indices[i])
            if deficit <= 0:
                continue
            by_dist = remain_pos[np.argsort(center_dist[remain_pos, i])]
            take = min(len(by_dist), deficit)
            sub_indices[i] = np.append(sub_indices[i], index[by_dist[:take]])
            remain_pos = by_dist[take:]

        # ``ave_num`` rounds down, so a few points may be left over; they go
        # to the first group.
        if len(remain_pos) > 0:
            sub_indices[0] = np.append(sub_indices[0], index[remain_pos])

        return sub_indices