# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import multiprocessing as mp
import os
import time
from multiprocessing.connection import Connection
from typing import Any, List, Optional

import numpy as np
import numpy.typing as npt
import pyarrow as pa
from sklearn.cluster import KMeans

from tzrec.datasets.dataset import create_reader
from tzrec.tools.tdm.gen_tree import tree_builder
from tzrec.tools.tdm.gen_tree.tree_builder import TDMTreeNode
from tzrec.utils.env_util import use_hash_node_id
from tzrec.utils.logging_util import logger


class TreeCluster:
"""Cluster based on emb vec.
Args:
item_input_path (str): The file path where the item information is stored.
item_id_field (str): The column name representing item_id in the file.
attr_fields (List[str]): The column names representing the features in the file.
output_dir (str): The output file.
parallel (int): The number of CPU cores for parallel processing.
n_cluster (int): The branching factor of the nodes in the tree.
"""
def __init__(
self,
item_input_path: str,
item_id_field: str,
attr_fields: Optional[str] = None,
raw_attr_fields: Optional[str] = None,
output_dir: Optional[str] = None,
embedding_field: str = "item_emb",
parallel: int = 16,
n_cluster: int = 2,
**kwargs: Any,
) -> None:
self.item_input_path = item_input_path
self.mini_batch = 1024
self.data = None
self.leaf_nodes = None
self.parallel = parallel
self.queue = None
self.timeout = 5
self.codes = None
self.output_dir = output_dir
self.n_clusters = n_cluster
self.item_id_field = item_id_field
self.attr_fields = []
self.raw_attr_fields = []
if attr_fields:
self.attr_fields = [x.strip() for x in attr_fields.split(",")]
if raw_attr_fields:
self.raw_attr_fields = [x.strip() for x in raw_attr_fields.split(",")]
self.embedding_field = embedding_field
self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

    def _read(self) -> None:
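        """Read item ids, attributes and embeddings into in-memory leaf nodes."""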
t1 = time.time()
data = list()
self.leaf_nodes = []
selected_cols = (
{self.item_id_field, self.embedding_field}
| set(self.attr_fields)
| set(self.raw_attr_fields)
)
reader = create_reader(
self.item_input_path,
4096,
selected_cols=list(selected_cols),
**self.dataset_kwargs,
)
for data_dict in reader.to_batches():
if use_hash_node_id():
ids = data_dict[self.item_id_field].cast(pa.string()).to_pylist()
else:
ids = data_dict[self.item_id_field].cast(pa.int64()).to_pylist()
data += data_dict[self.embedding_field].to_pylist()
batch_tree_nodes = []
for one_id in ids:
batch_tree_nodes.append(TDMTreeNode(item_id=one_id))
for attr in self.attr_fields:
attr_data = data_dict[attr]
for i in range(len(batch_tree_nodes)):
batch_tree_nodes[i].attrs.append(attr_data[i])
for attr in self.raw_attr_fields:
attr_data = data_dict[attr]
for i in range(len(batch_tree_nodes)):
batch_tree_nodes[i].raw_attrs.append(attr_data[i])
self.leaf_nodes.extend(batch_tree_nodes)
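        # Embeddings may be stored as strings (e.g. "[0.1, 0.2, ...]"); parse
        # them back into lists before stacking into a numpy array.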
if isinstance(data[0], str):
data = [eval(i) for i in data]
self.data = np.array(data)
t2 = time.time()
logger.info(
"Read data done, {} records read, elapsed: {}".format(
len(self.leaf_nodes), t2 - t1
)
)

    def train(self, save_tree: bool = False) -> TDMTreeNode:
        """Cluster all items, assign tree codes to the leaves and build the tree."""
self._read()
        # Each (code, index) pair on the queue holds a tree node's code and the
        # indices of the leaf items currently assigned to that node.
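        # With a branching factor of n_cluster, the children of the node with
        # code c get codes n_cluster * c + 1, ..., n_cluster * c + n_cluster;
        # e.g. for n_cluster=2 the root (code 0) has children 1 and 2, whose
        # children are 3, 4 and 5, 6.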
queue = mp.Queue()
queue.put((0, np.arange(len(self.leaf_nodes))))
processes = []
pipes = []
for _ in range(self.parallel):
parent_conn, child_conn = mp.Pipe()
p = mp.Process(target=self._train, args=(child_conn, queue))
processes.append(p)
pipes.append(parent_conn)
p.start()
self.codes = np.zeros((len(self.leaf_nodes),), dtype=np.int64)
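        # Each worker sends back a full-length code array in which only the
        # leaves it processed carry a positive code; merge all workers' results.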
for pipe in pipes:
codes = pipe.recv()
for i in range(len(codes)):
if codes[i] > 0:
self.leaf_nodes[i].tree_code = codes[i]
for p in processes:
p.join()
assert queue.empty()
builder = tree_builder.TreeBuilder(self.output_dir, self.n_clusters)
root = builder.build(self.leaf_nodes, save_tree)
return root

    # pyre-ignore [24]
def _train(self, pipe: Connection, queue: mp.Queue) -> None:
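        """Worker loop executed in a child process.

        Repeatedly takes a (parent_code, index) job from the shared queue.
        Chunks of at most ``mini_batch`` items are finished locally via
        ``_mini_batch``; larger chunks are split once with k-means and the
        sub-chunks are pushed back onto the queue. The per-leaf codes computed
        by this worker are finally sent to the parent process through ``pipe``.
        """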
last_size = -1
catch_time = 0
processed = False
code = np.zeros((len(self.leaf_nodes),), dtype=np.int64)
parent_code = None
index = None
while True:
for _ in range(3):
try:
parent_code, index = queue.get(timeout=self.timeout)
except Exception as _:
index = None
if index is not None:
break
if index is None:
if processed and (last_size <= self.mini_batch or catch_time >= 3):
logger.info("Process {} exits".format(os.getpid()))
break
else:
logger.info(
"Got empty job, pid: {}, time: {}".format(
os.getpid(), catch_time
)
)
catch_time += 1
continue
processed = True
catch_time = 0
last_size = len(index)
if last_size <= self.mini_batch:
self._mini_batch(parent_code, index, code)
else:
start = time.time()
sub_index = self._cluster(index)
logger.info(
"Train iteration done, parent_code:{}, "
"data size: {}, elapsed time: {}".format(
parent_code, len(index), time.time() - start
)
)
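                # Adapt the queue polling timeout to the observed iteration
                # time with an exponential moving average, floored at 5 seconds.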
self.timeout = int(0.4 * self.timeout + 0.6 * (time.time() - start))
if self.timeout < 5:
self.timeout = 5
for i in range(self.n_clusters):
if len(sub_index[i]) > 1:
queue.put((self.n_clusters * parent_code + i + 1, sub_index[i]))
process_count = 0
for c in code:
if c > 0:
process_count += 1
logger.info("Process {} process {} items".format(os.getpid(), process_count))
pipe.send(code)

    def _mini_batch(
self, parent_code: int, index: npt.NDArray, code: npt.NDArray
) -> None:
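        """Finish clustering a small chunk locally with a breadth-first queue.

        Chunks with at most ``n_clusters`` items become direct children of
        their parent node; larger chunks keep being split with k-means.
        """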
dq = collections.deque()
dq.append((parent_code, index))
batch_size = len(index)
tstart = time.time()
while dq:
parent_code, index = dq.popleft()
if len(index) <= self.n_clusters:
for i in range(len(index)):
code[index[i]] = self.n_clusters * parent_code + i + 1
continue
sub_index = self._cluster(index)
for i in range(self.n_clusters):
if len(sub_index[i]) > 1:
dq.append((self.n_clusters * parent_code + i + 1, sub_index[i]))
elif len(sub_index[i]) > 0:
for j in range(len(sub_index[i])):
code[sub_index[i][j]] = (
self.n_clusters * parent_code + i + j + 1
)
logger.info(
"Minibatch, batch size: {}, elapsed: {}".format(
batch_size, time.time() - tstart
)
)

    def _cluster(self, index: npt.NDArray) -> List[npt.NDArray]:
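        """Split ``index`` into ``n_clusters`` roughly equal-sized groups.

        KMeans assigns the initial labels; each cluster then keeps at most
        ``ave_num`` points closest to its center, and the surplus points are
        redistributed to the still under-filled clusters by ascending distance
        to their centers.
        """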
data = self.data[index]
kmeans = KMeans(n_clusters=self.n_clusters, random_state=0).fit(data)
labels = kmeans.labels_
sub_indices = []
remain_index = []
ave_num = int(len(index) / self.n_clusters)
for i in range(self.n_clusters):
sub_i = np.where(labels == i)[0]
sub_index = index[sub_i]
if len(sub_index) <= ave_num:
sub_indices.append(sub_index)
else:
distances = kmeans.transform(data[sub_i])[:, i]
sorted_index = sub_index[np.argsort(distances)]
sub_indices.append(sorted_index[:ave_num])
remain_index.extend(list(sorted_index[ave_num:]))
idx = 0
remain_index = np.array(remain_index)
        # Rebalance: assign the remaining points to clusters still below ave_num.
while idx < self.n_clusters and len(remain_index) > 0:
if len(sub_indices[idx]) >= ave_num:
idx += 1
else:
diff = min(len(remain_index), ave_num - len(sub_indices[idx]))
remain_data = self.data[remain_index]
distances = kmeans.transform(remain_data)[:, idx]
sorted_index = remain_index[np.argsort(distances)]
# Supplement the data by sorting the distances
# to the current cluster centers in ascending order.
sub_indices[idx] = np.append(
sub_indices[idx], np.array(sorted_index[0:diff])
)
remain_index = sorted_index[diff:]
idx += 1
if len(remain_index) > 0:
sub_indices[0] = np.append(sub_indices[0], remain_index)
return sub_indices
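

# A minimal usage sketch (illustrative only; the table path and column names
# below are hypothetical, and the real entry point in this repo may wire the
# arguments up differently):
#
#     cluster = TreeCluster(
#         item_input_path="item_table.parquet",
#         item_id_field="item_id",
#         attr_fields="cate_id,brand_id",
#         raw_attr_fields="price",
#         output_dir="tdm_tree_output",
#         embedding_field="item_emb",
#         parallel=8,
#         n_cluster=2,
#     )
#     root = cluster.train(save_tree=True)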