tzrec/tools/tdm/gen_tree/tree_builder.py (137 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import pickle
from collections import Counter
from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.compute as pc
from anytree import NodeMixin
from anytree.exporter.dictexporter import DictExporter

from tzrec.utils.env_util import use_hash_node_id
from tzrec.utils.logging_util import logger


class BaseClass(object):
    """Tree base class."""


class TDMTreeNode(BaseClass, NodeMixin):
    """TDM tree node.

    Args:
        tree_code (int): heap-style position of the node in the tree, the
            root is 0 and the children of code ``c`` are
            ``c * n_cluster + 1 .. c * n_cluster + n_cluster``. Defaults to -1.
        item_id (int|str): item id; for non-leaf nodes it is assigned by
            ``TreeBuilder.build``.
        cate (str): optional category value stored on the node.
        attrs (list): attribute scalars; for non-leaf nodes ``build`` fills
            them with the column-wise mode of the descendant leaves.
        raw_attrs (list): raw attribute scalars; for non-leaf nodes ``build``
            fills them with the column-wise mean of the descendant leaves.
        parent (TDMTreeNode): parent node.
        children (list): child nodes.
    """

    def __init__(
        self,
        tree_code: int = -1,
        item_id: Optional[Union[int, str]] = None,
        cate: Optional[str] = None,
        attrs: Optional[List[pa.Scalar]] = None,
        raw_attrs: Optional[List[pa.Scalar]] = None,
        parent: Optional["TDMTreeNode"] = None,
        children: Optional[List["TDMTreeNode"]] = None,
    ) -> None:
        super(TDMTreeNode, self).__init__()
        self.tree_code = tree_code
        self.item_id = item_id
        self.cate = cate
        self.attrs = attrs or []
        self.raw_attrs = raw_attrs or []
        # per-leaf attrs collected while building; cleared afterwards
        self.attrs_list: List[List[pa.Scalar]] = []
        self.raw_attrs_list: List[List[pa.Scalar]] = []
        self.parent = parent
        if children:
            self.children = children

    def set_parent(self, parent: "TDMTreeNode") -> None:
        """Set parent."""
        self.parent = parent

    def set_children(self, children: List["TDMTreeNode"]) -> None:
        """Set children."""
        self.children = children


class TreeBuilder:
    """Build tree base codes.

    Args:
        output_dir(str): directory that ``tree.pkl`` is written into when the
            tree is saved.
        n_cluster(int): The branching factor of the nodes in the tree.
    """

    def __init__(self, output_dir: Optional[str] = ".", n_cluster: int = 2) -> None:
        self.output_dir = output_dir
        self.n_cluster = n_cluster

    def build(
        self,
        leaf_nodes: List[TDMTreeNode],
        save_tree: bool = False,
    ) -> TDMTreeNode:
        """Build a complete n-ary tree on top of the given leaf nodes.

        Every leaf is first pushed down (through its left-most descendant) to
        the deepest level, then all missing ancestors are created. Each
        non-leaf node gets the column-wise mode of its descendant leaves'
        ``attrs``, the column-wise mean of their ``raw_attrs`` and a generated
        ``item_id``.

        Args:
            leaf_nodes (list): leaf nodes with ``tree_code`` and ``item_id``
                set. Their ``tree_code`` is updated in place.
            save_tree (bool): whether to pickle the tree into ``output_dir``.

        Returns:
            root_node (TDMTreeNode): root of the built tree.

        Raises:
            ValueError: if ``n_cluster < 2``, ``leaf_nodes`` is empty or a
                leaf has a negative ``tree_code``.
        """
        if self.n_cluster < 2:
            raise ValueError(f"n_cluster must be >= 2, but got {self.n_cluster}.")
        if not leaf_nodes:
            raise ValueError("leaf_nodes must not be empty.")

        # pull all leaf nodes to the last level
        min_code = self._first_code_of_level(self._leaf_level(leaf_nodes))
        max_code = 0
        max_item_id = 0
        for leaf_node in leaf_nodes:
            # the left-most child of code c is c * n_cluster + 1
            while leaf_node.tree_code < min_code:
                leaf_node.tree_code = leaf_node.tree_code * self.n_cluster + 1
            max_code = max(leaf_node.tree_code, max_code)
            leaf_item_id = leaf_node.item_id
            assert leaf_item_id is not None
            if not use_hash_node_id():
                max_item_id = max(leaf_item_id, max_item_id)

        tree_nodes: List[Optional[TDMTreeNode]] = [None] * (max_code + 1)
        logger.info("start gen code_list")
        for leaf_node in leaf_nodes:
            tree_nodes[leaf_node.tree_code] = leaf_node
            for ancestor in self._ancestors(leaf_node.tree_code):
                ancestor_node = tree_nodes[ancestor]
                if ancestor_node is None:
                    ancestor_node = TDMTreeNode(tree_code=ancestor)
                    tree_nodes[ancestor] = ancestor_node
                # every ancestor aggregates the attrs of all leaves below it
                ancestor_node.attrs_list.append(leaf_node.attrs)
                ancestor_node.raw_attrs_list.append(leaf_node.raw_attrs)

        # ascending code order, so children are attached in code order
        for code, node in enumerate(tree_nodes):
            if node is None:
                continue
            if node.item_id is None:
                node.attrs = self._column_modes(node.attrs_list)
                node.raw_attrs = self._column_means(node.raw_attrs_list)
                node.item_id = (
                    # pyre-ignore [58]
                    f"nonleaf#{code}" if use_hash_node_id() else max_item_id + code + 1
                )
            if code > 0:
                parent_node = tree_nodes[(code - 1) // self.n_cluster]
                assert parent_node is not None
                node.set_parent(parent_node)
            # aggregated already; drop the per-leaf copies
            node.attrs_list = []
            node.raw_attrs_list = []

        root_node = tree_nodes[0]
        assert root_node is not None
        if save_tree:
            self.save_tree(root_node)
        return root_node

    def _leaf_level(self, leaf_nodes: List[TDMTreeNode]) -> int:
        """Return the level (root is 0) all leaves must be pulled down to.

        It is the smallest level wide enough for all leaves, or the deepest
        level any leaf already sits on, whichever is larger. Integer
        arithmetic is used because ``math.log`` can round up exact powers
        (e.g. ``math.log(125, 5) == 3.0000000000000004``).

        Raises:
            ValueError: if a leaf has a negative ``tree_code``.
        """
        level = 0
        width = 1
        while width < len(leaf_nodes):
            width *= self.n_cluster
            level += 1
        for leaf_node in leaf_nodes:
            if leaf_node.tree_code < 0:
                # pushing a negative code down never terminates
                raise ValueError(
                    f"leaf node {leaf_node.item_id} has invalid tree_code "
                    f"{leaf_node.tree_code}."
                )
            level = max(level, len(self._ancestors(leaf_node.tree_code)))
        return level

    def _first_code_of_level(self, level: int) -> int:
        """Return the smallest tree code on ``level`` (root is level 0)."""
        # levels 0..level-1 hold 1 + n + ... + n**(level - 1) codes
        return (self.n_cluster**level - 1) // (self.n_cluster - 1)

    def _column_modes(self, matrix: List[List[pa.Scalar]]) -> List[pa.Scalar]:
        """Column-wise mode of a row-major matrix of scalars.

        String columns ignore null and empty values and fall back to an empty
        string of the column type when nothing is left. Other columns use
        ``pc.mode`` and fall back to the first value (null) when no mode exists.
        """
        modes = []
        for column in zip(*matrix):
            dtype = column[0].type
            if pa.types.is_string(dtype) or pa.types.is_large_string(dtype):
                # test validity and the Python value explicitly; relying on the
                # truthiness of pa.Scalar objects does not reliably drop
                # nulls or empty strings.
                filtered_column = [x for x in column if x.is_valid and x.as_py()]
                if filtered_column:
                    modes.append(Counter(filtered_column).most_common(1)[0][0])
                else:
                    modes.append(pa.scalar("", type=dtype))
            else:
                mode = pc.mode(list(column))
                if len(mode) > 0:
                    modes.append(mode[0][0])
                else:
                    # null value with column dtype
                    modes.append(column[0])
        return modes

    def _column_means(self, matrix: List[List[pa.Scalar]]) -> List[pa.Scalar]:
        """Column-wise mean of a row-major matrix of numeric scalars.

        Integer columns are rounded before being cast back, so every mean
        keeps its column dtype.
        """
        means = []
        for column in zip(*matrix):
            dtype = column[0].type
            mean = pc.mean(list(column))
            if pa.types.is_integer(dtype):
                mean = pc.round(mean)
            means.append(mean.cast(dtype, safe=False))
        return means

    def save_tree(self, root: TDMTreeNode) -> None:
        """Pickle the anytree dict export of ``root`` to output_dir/tree.pkl."""
        assert self.output_dir is not None, "if save tree, must set output_dir."
        path = os.path.join(self.output_dir, "tree.pkl")
        logger.info(f"save tree to {path}")
        exporter = DictExporter()
        data = exporter.export(root)
        with open(path, "wb") as f:
            pickle.dump(data, f)

    def _ancestors(self, code: int) -> List[int]:
        """Return the ancestor codes of ``code``, nearest first, root last."""
        ancs = []
        while code > 0:
            # integer division keeps large codes exact
            code = (code - 1) // self.n_cluster
            ancs.append(code)
        return ancs