in tzrec/tools/tdm/gen_tree/tree_cluster.py [0:0]
def _cluster(self, index: npt.NDArray) -> List[npt.NDArray]:
    """Split ``index`` into ``self.n_clusters`` size-balanced groups with KMeans.

    KMeans (``random_state=0``) is fit on ``self.data[index]``. Each cluster
    is then trimmed or topped up to ``ave_num = len(index) // n_clusters``
    members:

    1. A cluster larger than ``ave_num`` keeps its ``ave_num`` members
       closest to its own center; the surplus goes into a shared pool.
    2. Clusters smaller than ``ave_num`` are visited in order, and each takes
       the pooled points closest to its own center until it reaches
       ``ave_num``.
    3. The ``len(index) % n_clusters`` points still pooled are appended to
       cluster 0.

    Args:
        index: integer array of row positions into ``self.data`` to cluster.

    Returns:
        A list of ``self.n_clusters`` arrays holding entries of ``index``.
        Cluster 0 has ``ave_num + len(index) % n_clusters`` entries; every
        other cluster has exactly ``ave_num``.
    """
    data = self.data[index]
    # NOTE(review): assumes len(index) >= self.n_clusters, which KMeans
    # requires -- confirm callers guarantee it.
    kmeans = KMeans(n_clusters=self.n_clusters, random_state=0).fit(data)
    labels = kmeans.labels_
    # Distance from every point to every center, computed once and shared by
    # both passes instead of re-transforming the remaining pool per cluster.
    distances = kmeans.transform(data)
    # Exact integer floor division (int(a / b) detours through a float).
    ave_num = len(index) // self.n_clusters

    # Work with positions into ``index``/``data``; map back to ``index`` at
    # the very end.
    sub_positions = []
    surplus = []
    for i in range(self.n_clusters):
        members = np.where(labels == i)[0]
        if len(members) > ave_num:
            # Keep the ``ave_num`` points nearest to center ``i``; pool the rest.
            members = members[np.argsort(distances[members, i])]
            surplus.append(members[ave_num:])
            members = members[:ave_num]
        sub_positions.append(members)
    remain = np.concatenate(surplus) if surplus else np.empty(0, dtype=np.intp)

    # Rebalance: top up undersized clusters, in order, from the pool, taking
    # the points closest to each cluster's own center first.
    for i in range(self.n_clusters):
        if len(remain) == 0:
            break
        need = ave_num - len(sub_positions[i])
        if need <= 0:
            continue
        remain = remain[np.argsort(distances[remain, i])]
        sub_positions[i] = np.append(sub_positions[i], remain[:need])
        remain = remain[need:]

    # Whatever is left (len(index) % n_clusters points) goes to cluster 0.
    if len(remain) > 0:
        sub_positions[0] = np.append(sub_positions[0], remain)
    return [index[pos] for pos in sub_positions]