tzrec/tools/tdm/cluster_tree.py (98 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse from tzrec.tools.tdm.gen_tree.tree_cluster import TreeCluster from tzrec.tools.tdm.gen_tree.tree_search_util import TreeSearch from tzrec.utils.logging_util import logger if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--item_input_path", type=str, default=None, help="The file path where the item embedding is stored.", ) parser.add_argument( "--item_id_field", type=str, default=None, help="The column name representing item_id in the file.", ) parser.add_argument( "--embedding_field", type=str, default="item_emb", help="The column name representing item embedding in the file.", ) parser.add_argument( "--attr_fields", type=str, default=None, help="The column names representing the non-raw features of item in the file.", ) parser.add_argument( "--raw_attr_fields", type=str, default=None, help="The column names representing the raw features of item in the file.", ) parser.add_argument( "--attr_delimiter", type=str, default=",", help="The attribute delimiter in tdm node and edge table.", ) parser.add_argument( "--tree_output_dir", type=str, default=None, help="The tree output directory.", ) parser.add_argument( "--node_edge_output_file", type=str, default=None, help="The nodes and edges table output file.", ) parser.add_argument( "--parallel", type=int, default=16, help="The number of CPU cores for parallel processing.", ) parser.add_argument( "--n_cluster", type=int, default=2, help="The branching factor of the nodes in the tree.", ) parser.add_argument( "--odps_data_quota_name", type=str, default="pay-as-you-go", help="maxcompute storage api/tunnel data quota name.", ) args, extra_args = parser.parse_known_args() cluster = TreeCluster( item_input_path=args.item_input_path, item_id_field=args.item_id_field, attr_fields=args.attr_fields, raw_attr_fields=args.raw_attr_fields, output_dir=args.tree_output_dir, embedding_field=args.embedding_field, parallel=args.parallel, n_cluster=args.n_cluster, odps_data_quota_name=args.odps_data_quota_name, ) root = cluster.train() logger.info("Tree cluster done. Start save nodes and edges table.") tree_search = TreeSearch( output_file=args.node_edge_output_file, root=root, child_num=args.n_cluster, odps_data_quota_name=args.odps_data_quota_name, ) tree_search.save(attr_delimiter=args.attr_delimiter) tree_search.save_predict_edge() tree_search.save_node_feature(args.attr_fields, args.raw_attr_fields) if args.tree_output_dir: tree_search.save_serving_tree(args.tree_output_dir) logger.info("Save nodes and edges table done.")