tzrec/tools/tdm/gen_tree/tree_search_util.py (247 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
from collections import OrderedDict
from typing import Any, Callable, Iterator, List, Optional, Tuple

import pyarrow as pa
from anytree.importer.dictimporter import DictImporter
from anytree.iterators.abstractiter import AbstractIter
from anytree.walker import Walker

from tzrec.datasets.dataset import create_writer
from tzrec.tools.tdm.gen_tree.tree_builder import TDMTreeNode
from tzrec.utils.env_util import use_hash_node_id
from tzrec.utils.logging_util import logger


class LevelOrderIter(AbstractIter):
    """Level-order (breadth-first) traversal of a tree.

    Yields ``(node, level)`` pairs. Nodes in the initial ``children`` list
    (the start node) are reported at level 1, their children at level 2, etc.
    """

    @staticmethod
    def _iter(
        children: List[TDMTreeNode],
        filter_: Callable[[TDMTreeNode], bool],
        stop: Callable[[TDMTreeNode], bool],
        maxlevel: int,
    ) -> Iterator[Tuple[TDMTreeNode, int]]:
        level = 1
        while children:
            next_children = []
            for child in children:
                # filter_ only controls whether a node is yielded; its
                # children are still visited.
                if filter_(child):
                    yield child, level
                next_children += AbstractIter._get_children(child.children, stop)
            children = next_children
            level += 1
            if AbstractIter._abort_at_level(level, maxlevel):
                break


def _add_suffix_to_odps_table(table_path: str, suffix: str) -> str:
    """Append ``suffix`` to the table-name component of an ODPS table path.

    The table name is taken to be element 4 of ``table_path.split("/")``,
    e.g. ``table`` in ``odps://project/tables/table`` (assumed path layout);
    any trailing components (such as a partition spec) are kept unchanged.

    Raises:
        ValueError: if ``table_path`` has no table-name component.
    """
    parts = table_path.split("/")
    if len(parts) < 5 or not parts[4]:
        raise ValueError(f"Invalid odps table path: {table_path}")
    parts[4] += suffix
    return "/".join(parts)


def _pseudo_root_id() -> Any:
    """Return the id of the pseudo node (-1) graph-learn uses to reach the root.

    The id is the string "-1" when hashed (string) node ids are enabled and
    the integer -1 otherwise, matching the node id column type.
    """
    return "-1" if use_hash_node_id() else -1


def _build_node_features(
    node: TDMTreeNode, level: int, attr_delimiter: str
) -> List[Any]:
    """Build the feature fields ``[level, item_id, attrs?, raw_attrs?]``.

    ``attrs`` / ``raw_attrs`` are each joined with ``attr_delimiter`` into a
    single field, and only appended when non-empty. Their items expose
    ``is_valid`` (presumably pyarrow scalars); invalid ``attrs`` values are
    written as "" and invalid ``raw_attrs`` values as "0".
    """
    fea: List[Any] = [level, node.item_id]
    if node.attrs:
        fea.append(
            attr_delimiter.join(str(x) if x.is_valid else "" for x in node.attrs)
        )
    if node.raw_attrs:
        # The null placeholder must be the *string* "0": str.join rejects
        # non-str items, so an int 0 raised TypeError for any null raw attr.
        fea.append(
            attr_delimiter.join(
                str(x) if x.is_valid else "0" for x in node.raw_attrs
            )
        )
    return fea


class TreeSearch(object):
    """Convert an anytree tree into graph-learn node and edge tables.

    Args:
        output_file (str): output location for nodes and edges, either an
            ``odps://`` table path or a local directory.
        tree_path (str): pickled tree file, loaded when ``root`` is None.
        root (TDMTreeNode): root node of the tree.
        child_num (int): The branching factor of the nodes in the tree.
        **kwargs: ``odps_data_quota_name``, if present, is forwarded to the
            table writers as ``quota_name``.
    """

    def __init__(
        self,
        output_file: str,
        tree_path: Optional[str] = None,
        root: Optional[TDMTreeNode] = None,
        child_num: int = 2,
        **kwargs: Any,
    ) -> None:
        self.child_num = child_num
        if root is not None:
            self.root = root
        elif tree_path is not None:
            self._load(tree_path)
        else:
            raise ValueError("Either root or tree_path must be provided.")
        assert self.root is not None, "Either root or tree_path must be provided."
        # One id path per deepest-level leaf: [leaf_id, parent_id, ..., root_id].
        self.travel_list: List[List[Any]] = []
        # level_code[k] holds the nodes at depth k (the root is at depth 0).
        self.level_code: List[List[TDMTreeNode]] = [[]]
        self.max_level = 0
        self.output_file = output_file
        self.dataset_kwargs = {}
        if "odps_data_quota_name" in kwargs:
            self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]
        self._get_nodes()

    def _load(self, path: str) -> None:
        """Load a pickled dict-exported tree from ``path`` into ``self.root``."""
        logger.info("Begin load tree.")
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # tree files produced by a trusted source.
        with open(path, "rb") as f:
            data = pickle.load(f)
        self.root = DictImporter().import_(data)

    def _get_nodes(self) -> None:
        """Group nodes by depth and record each deepest leaf's path to the root."""
        logger.info("Begin iter tree.")
        for node, level in LevelOrderIter(self.root):
            depth = level - 1  # LevelOrderIter reports the root at level 1.
            if depth > self.max_level:
                self.max_level = depth
                self.level_code.append([])
            self.level_code[self.max_level].append(node)
        logger.info(
            f"Tree Level: {self.max_level + 1}, Tree Cluster: {self.child_num}."
        )
        tree_walker = Walker()
        logger.info("Begin Travel Tree.")
        for leaf_node in self.level_code[-1]:
            # walk(leaf, root) -> (upward nodes starting at the leaf, common
            # node i.e. the root, downward nodes which are empty here).
            upwards, common, _ = tree_walker.walk(leaf_node, self.root)
            path = list(upwards) + [common]
            self.travel_list.append([n.item_id for n in path])

    def _iter_node_rows(self, attr_delimiter: str) -> Iterator[Tuple[Any, str]]:
        """Yield ``(node_id, feature_string)`` rows for the node table.

        Before the first (root) node, a pseudo node with id -1 is emitted,
        carrying the root's features with the level field replaced by -1, so
        graph-learn can locate the root.
        """
        first_node = True
        for level, nodes in enumerate(self.level_code):
            for node in nodes:
                fea = _build_node_features(node, level, attr_delimiter)
                if first_node:
                    pseudo_features = ",".join(["-1"] + [str(v) for v in fea[1:]])
                    yield _pseudo_root_id(), pseudo_features
                    first_node = False
                yield node.item_id, ",".join(map(str, fea))

    def _iter_train_edges(self) -> Iterator[Tuple[Any, Any]]:
        """Yield ``(leaf_id, ancestor_id)`` edges, excluding leaf -> root."""
        for travel in self.travel_list:
            # travel is [leaf, parent, ..., root] with max_level + 1 entries;
            # skip the leaf itself (index 0) and the root (last index).
            for ancestor_id in travel[1 : self.max_level]:
                yield travel[0], ancestor_id

    def _iter_predict_edges(self) -> Iterator[Tuple[Any, Any]]:
        """Yield the pseudo edge (-1 -> root) and every parent -> child edge."""
        # Edge from the pseudo node -1 to the root, so graph-learn can find
        # the root node.
        yield _pseudo_root_id(), self.root.item_id
        for nodes in self.level_code[: self.max_level]:
            for node in nodes:
                for child in node.children:
                    yield node.item_id, child.item_id

    def _write_odps_table(
        self, suffix: str, table_dict: "OrderedDict[str, pa.Array]"
    ) -> None:
        """Write ``table_dict`` to the ODPS table ``output_file`` + ``suffix``."""
        output_path = _add_suffix_to_odps_table(self.output_file, suffix)
        writer = create_writer(output_path, **self.dataset_kwargs)
        writer.write(table_dict)
        writer.close()

    def _write_odps_edges(self, suffix: str, edges: Iterator[Tuple[Any, Any]]) -> None:
        """Write ``(src_id, dst_id)`` edges, each with weight 1.0, to ODPS."""
        src_ids = []
        dst_ids = []
        for src_id, dst_id in edges:
            src_ids.append(src_id)
            dst_ids.append(dst_id)
        edge_table_dict = OrderedDict()
        edge_table_dict["src_id"] = pa.array(src_ids)
        edge_table_dict["dst_id"] = pa.array(dst_ids)
        edge_table_dict["weight"] = pa.array([1.0] * len(src_ids))
        self._write_odps_table(suffix, edge_table_dict)

    def _write_text_edges(
        self, file_name: str, edges: Iterator[Tuple[Any, Any]]
    ) -> None:
        """Write edges as a tab-separated file under ``output_file``.

        The first line is a typed header; each edge has weight 1.0.
        """
        id_type = "string" if use_hash_node_id() else "int64"
        with open(os.path.join(self.output_file, file_name), "w") as f:
            f.write(f"src_id:{id_type}\tdst_id:{id_type}\tweight:float\n")
            for src_id, dst_id in edges:
                f.write(f"{src_id}\t{dst_id}\t1.0\n")

    def save(self, attr_delimiter: str = ",") -> None:
        """Save the node table and the leaf-to-ancestor edge table.

        ODPS outputs go to tables suffixed ``_node_table`` / ``_edge_table``;
        local outputs go to ``node_table.txt`` / ``edge_table.txt`` inside the
        ``output_file`` directory, which is created if missing.

        Args:
            attr_delimiter (str): delimiter joining a node's attribute values
                within its feature string.
        """
        node_rows = self._iter_node_rows(attr_delimiter)
        if self.output_file.startswith("odps://"):
            ids = []
            features = []
            for node_id, feature in node_rows:
                ids.append(node_id)
                features.append(feature)
            node_table_dict = OrderedDict()
            node_table_dict["id"] = pa.array(ids)
            node_table_dict["weight"] = pa.array([1.0] * len(ids))
            node_table_dict["features"] = pa.array(features)
            self._write_odps_table("_node_table", node_table_dict)
            self._write_odps_edges("_edge_table", self._iter_train_edges())
        else:
            os.makedirs(self.output_file, exist_ok=True)
            id_type = "string" if use_hash_node_id() else "int64"
            with open(os.path.join(self.output_file, "node_table.txt"), "w") as f:
                f.write(f"id:{id_type}\tweight:float\tfeature:string\n")
                for node_id, feature in node_rows:
                    f.write(f"{node_id}\t1.0\t{feature}\n")
            self._write_text_edges("edge_table.txt", self._iter_train_edges())

    def save_predict_edge(self) -> None:
        """Save all parent -> child edges (plus the -1 -> root edge) for prediction.

        ODPS outputs go to a table suffixed ``_predict_edge_table``; local
        outputs go to ``predict_edge_table.txt`` inside ``output_file``, which
        is created if missing.
        """
        edges = self._iter_predict_edges()
        if self.output_file.startswith("odps://"):
            self._write_odps_edges("_predict_edge_table", edges)
        else:
            os.makedirs(self.output_file, exist_ok=True)
            self._write_text_edges("predict_edge_table.txt", edges)

    def save_node_feature(
        self, attr_fields: Optional[str] = None, raw_attr_fields: Optional[str] = None
    ) -> None:
        """Save feature of tree node for serving.

        Writes one row per tree node with columns ``item_id`` followed by the
        comma-separated ``attr_fields`` and ``raw_attr_fields`` names. Values
        are taken positionally from ``node.attrs`` then ``node.raw_attrs``.
        ODPS outputs use an ``OdpsWriter`` on a table suffixed
        ``_node_feature``; otherwise a ``ParquetWriter`` writes to
        ``<output_file>/node_feature``.
        """
        if self.output_file.startswith("odps://"):
            output_path = _add_suffix_to_odps_table(self.output_file, "_node_feature")
            writer_type = "OdpsWriter"
        else:
            output_path = os.path.join(self.output_file, "node_feature")
            writer_type = "ParquetWriter"
        writer = create_writer(
            output_path, writer_type=writer_type, **self.dataset_kwargs
        )
        attr_field_names = (
            [x.strip() for x in attr_fields.split(",")] if attr_fields else []
        )
        raw_attr_field_names = (
            [x.strip() for x in raw_attr_fields.split(",")] if raw_attr_fields else []
        )
        attr_names = ["item_id"] + attr_field_names + raw_attr_field_names
        attr_values = [[] for _ in range(len(attr_names))]
        for nodes in self.level_code:
            for node in nodes:
                # assumes node.attrs / node.raw_attrs are lists whose lengths
                # match the given field names — TODO confirm with tree_builder.
                row = [pa.scalar(node.item_id)] + node.attrs + node.raw_attrs
                for i, attr_value in enumerate(row):
                    attr_values[i].append(attr_value)
        attr_dict = OrderedDict(
            zip(attr_names, [pa.array(attr_arr) for attr_arr in attr_values])
        )
        writer.write(attr_dict)
        writer.close()

    def save_serving_tree(self, tree_output_dir: str) -> None:
        """Save tree info for serving.

        Writes ``<tree_output_dir>/serving_tree``: a first line
        ``"<num_levels> <child_num>"`` followed by one ``"<tree_code> <item_id>"``
        line per node in level order.
        """
        os.makedirs(tree_output_dir, exist_ok=True)
        with open(os.path.join(tree_output_dir, "serving_tree"), "w") as f:
            f.write(f"{self.max_level + 1} {self.child_num}\n")
            for nodes in self.level_code:
                for node in nodes:
                    f.write(f"{node.tree_code} {node.item_id}\n")