tzrec/tools/tdm/init_tree.py (91 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse from tzrec.tools.tdm.gen_tree.tree_generator import TreeGenerator from tzrec.tools.tdm.gen_tree.tree_search_util import TreeSearch from tzrec.utils.logging_util import logger if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--item_input_path", type=str, default=None, help="The file path where the item information is stored.", ) parser.add_argument( "--item_id_field", type=str, default=None, help="The column name representing item_id in the file.", ) parser.add_argument( "--cate_id_field", type=str, default=None, help="The column name representing the category of item in the file.", ) parser.add_argument( "--attr_fields", type=str, default=None, help="The column names representing the non-raw features of item in the file.", ) parser.add_argument( "--raw_attr_fields", type=str, default=None, help="The column names representing the raw features of item in the file.", ) parser.add_argument( "--attr_delimiter", type=str, default=",", help="The attribute delimiter in tdm node and edge table.", ) parser.add_argument( "--tree_output_dir", type=str, default=None, help="The tree output directory.", ) parser.add_argument( "--node_edge_output_file", type=str, default=None, help="The nodes and edges table output file.", ) parser.add_argument( "--n_cluster", type=int, default=2, help="The branching factor of the nodes in the tree.", ) parser.add_argument( "--odps_data_quota_name", type=str, default="pay-as-you-go", help="maxcompute storage api/tunnel data quota name.", ) args, extra_args = parser.parse_known_args() generator = TreeGenerator( item_input_path=args.item_input_path, item_id_field=args.item_id_field, cate_id_field=args.cate_id_field, attr_fields=args.attr_fields, raw_attr_fields=args.raw_attr_fields, tree_output_dir=args.tree_output_dir, n_cluster=args.n_cluster, odps_data_quota_name=args.odps_data_quota_name, ) root = generator.generate() logger.info("Tree init done. Start save nodes and edges table.") tree_search = TreeSearch( output_file=args.node_edge_output_file, root=root, child_num=args.n_cluster, odps_data_quota_name=args.odps_data_quota_name, ) tree_search.save(attr_delimiter=args.attr_delimiter) tree_search.save_predict_edge() tree_search.save_node_feature(args.attr_fields, args.raw_attr_fields) if args.tree_output_dir: tree_search.save_serving_tree(args.tree_output_dir) logger.info("Save nodes and edges table done.")