tzrec/tools/tdm/gen_tree/tree_generator.py (92 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional

import pyarrow as pa

from tzrec.datasets.dataset import create_reader
from tzrec.tools.tdm.gen_tree.tree_builder import TDMTreeNode, TreeBuilder
from tzrec.utils.env_util import use_hash_node_id


def _split_fields(fields: Optional[str]) -> List[str]:
    """Split a comma-separated field string into stripped, non-empty names."""
    if not fields:
        return []
    # Drop empty entries so that "a,b," or "a,,b" do not yield a "" column name.
    return [name for name in (part.strip() for part in fields.split(",")) if name]


class TreeGenerator:
    """Generate tree and train file.

    Args:
        item_input_path(str): The file path where the item information is stored.
        item_id_field(str): The column name representing item_id in the file.
        cate_id_field(str): The column name representing the category in the file.
        attr_fields(str, optional): Comma-separated column names; the values of
            these columns are appended to each leaf node's `attrs`.
        raw_attr_fields(str, optional): Comma-separated column names; the values
            of these columns are appended to each leaf node's `raw_attrs`.
        tree_output_dir(str, optional): The output location handed to TreeBuilder.
        n_cluster(int): The branching factor of the nodes in the tree.
        **kwargs: Extra options. Only `odps_data_quota_name` is read; it is
            forwarded to the reader as `quota_name`.
    """

    # Number of rows requested from the reader per batch.
    _READ_BATCH_SIZE = 4096

    def __init__(
        self,
        item_input_path: str,
        item_id_field: str,
        cate_id_field: str,
        attr_fields: Optional[str] = None,
        raw_attr_fields: Optional[str] = None,
        tree_output_dir: Optional[str] = None,
        n_cluster: int = 2,
        **kwargs: Any,
    ) -> None:
        self.item_input_path = item_input_path
        self.item_id_field = item_id_field
        self.cate_id_field = cate_id_field
        self.attr_fields = _split_fields(attr_fields)
        self.raw_attr_fields = _split_fields(raw_attr_fields)
        self.tree_output_dir = tree_output_dir
        self.n_cluster = n_cluster

        self.dataset_kwargs = {}
        if "odps_data_quota_name" in kwargs:
            self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

    def generate(self, save_tree: bool = False) -> TDMTreeNode:
        """Generate tree.

        Args:
            save_tree(bool): Forwarded unchanged to `TreeBuilder.build`.

        Returns:
            The root node returned by `TreeBuilder.build`.
        """
        item_fea = self._read()
        root = self._init_tree(item_fea, save_tree)
        return root

    def _read(self) -> List[TDMTreeNode]:
        """Read the item table and create one leaf node per row."""
        # dict.fromkeys de-duplicates while keeping a deterministic column order.
        selected_cols = list(
            dict.fromkeys(
                [
                    self.item_id_field,
                    self.cate_id_field,
                    *self.attr_fields,
                    *self.raw_attr_fields,
                ]
            )
        )
        reader = create_reader(
            self.item_input_path,
            self._READ_BATCH_SIZE,
            selected_cols=selected_cols,
            **self.dataset_kwargs,
        )

        leaf_nodes = []
        for data_dict in reader.to_batches():
            # Item ids are kept as strings when hashed node ids are enabled,
            # otherwise they are cast to int64.
            id_type = pa.string() if use_hash_node_id() else pa.int64()
            ids = data_dict[self.item_id_field].cast(id_type).to_pylist()
            # Missing categories become "" so that the later sort can compare them.
            cates = (
                data_dict[self.cate_id_field]
                .cast(pa.string())
                .fill_null("")
                .to_pylist()
            )

            batch_tree_nodes = [
                TDMTreeNode(item_id=one_id, cate=one_cate)
                for one_id, one_cate in zip(ids, cates)
            ]

            # Attribute values are stored as returned by indexing the column
            # (no conversion to Python objects is done here).
            for attr in self.attr_fields:
                attr_data = data_dict[attr]
                for i, node in enumerate(batch_tree_nodes):
                    node.attrs.append(attr_data[i])

            for attr in self.raw_attr_fields:
                attr_data = data_dict[attr]
                for i, node in enumerate(batch_tree_nodes):
                    node.raw_attrs.append(attr_data[i])

            leaf_nodes.extend(batch_tree_nodes)

        return leaf_nodes

    def _init_tree(
        self, leaf_nodes: List[TDMTreeNode], save_tree: bool
    ) -> TDMTreeNode:
        """Assign an initial tree code to every leaf and build the tree.

        Leaves are sorted in place by (cate, item_id), then the sorted range is
        recursively cut into `n_cluster` contiguous slices until each slice
        holds a single leaf, which receives the code of that slice.

        Args:
            leaf_nodes(List[TDMTreeNode]): Leaf nodes; sorted in place.
            save_tree(bool): Forwarded unchanged to `TreeBuilder.build`.

        Returns:
            The root node returned by `TreeBuilder.build`.

        Raises:
            ValueError: If more than one leaf has to be split but `n_cluster`
                is smaller than 2.
        """
        n_cluster = self.n_cluster
        # With n_cluster == 1 a slice is never reduced (endless recursion), and
        # with n_cluster <= 0 no leaf would ever receive a code.
        if n_cluster < 2 and len(leaf_nodes) > 1:
            raise ValueError(
                f"n_cluster must be >= 2 to split {len(leaf_nodes)} leaf nodes, "
                f"got {n_cluster}."
            )

        leaf_nodes.sort(key=lambda x: (x.cate, x.item_id))

        def gen_code(start: int, end: int, code: int) -> None:
            """Assign codes to leaf_nodes[start:end] below the node `code`."""
            if end <= start:
                return
            if end == start + 1:
                leaf_nodes[start].tree_code = code
                return
            span = end - start
            for i in range(n_cluster):
                # Integer floor division keeps the boundaries exact for any size.
                left = start + i * span // n_cluster
                right = start + (i + 1) * span // n_cluster
                # Children of `code` use n_cluster * code + 1 .. + n_cluster;
                # the first slice receives the largest of these codes.
                gen_code(left, right, n_cluster * code + n_cluster - i)

        gen_code(0, len(leaf_nodes), 0)

        builder = TreeBuilder(self.tree_output_dir, self.n_cluster)
        root = builder.build(leaf_nodes, save_tree)
        return root