tzrec/tools/create_fg_json.py (123 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Command line tool that writes the feature generator json of a pipeline config.

The fg json (plus any asset files written by ``create_fg_json``) is produced in
a temporary directory, optionally copied to ``--fg_output_dir`` and optionally
uploaded as ODPS file resources.
"""

import argparse
import copy
import json
import os
import shutil
import tempfile

from odps import ODPS

from tzrec.datasets.odps_dataset import _create_odps_account
from tzrec.features.feature import create_fg_json
from tzrec.main import _create_features, _get_dataloader
from tzrec.utils import config_util
from tzrec.utils.logging_util import logger

# Per-feature keys removed when --remove_bucketizer is set.
_BUCKETIZER_KEYS = (
    "hash_bucket_size",
    "vocab_dict",
    "vocab_list",
    "boundaries",
    "num_buckets",
)


def _strip_bucketizer_params(fg_json):
    """Return a copy of ``fg_json`` without bucketizer params in its features.

    Args:
        fg_json: dict with a "features" list of per-feature dicts.

    Returns:
        A deep copy of ``fg_json``; the input is left untouched.
    """
    # Deep copy: a shallow copy would still share (and mutate) the feature dicts.
    stripped = copy.deepcopy(fg_json)
    for feature in stripped["features"]:
        for key in _BUCKETIZER_KEYS:
            # Not every feature carries every key, so a missing key is not an error.
            feature.pop(key, None)
        # tokenize_feature keeps its vocab_file; all other feature types drop it.
        if feature["feature_type"] != "tokenize_feature":
            feature.pop("vocab_file", None)
    return stripped


def _upload_resources(odps_client, resource_dir, force_update):
    """Upload every entry of ``resource_dir`` as an ODPS file resource.

    Args:
        odps_client: ODPS entry object used to query/create/delete resources.
        resource_dir: local directory whose entries are uploaded under their
            own file names.
        force_update: if True, an already existing resource is deleted and
            re-created; otherwise it is left as is.
    """
    for fname in os.listdir(resource_dir):
        fpath = os.path.join(resource_dir, fname)
        if odps_client.exist_resource(fname):
            if not force_update:
                logger.info(
                    f"{fname} has already existed, skip uploading. "
                    "use --force_update_resource to overwrite it."
                )
                continue
            odps_client.delete_resource(fname)
            logger.info(f"{fname} has already existed, will update this resource !")
        else:
            logger.info(f"uploading resource [{fname}].")
        # Close the local file handle once the resource has been created.
        with open(fpath, "rb") as f:
            odps_client.create_resource(fname, "file", file_obj=f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pipeline_config_path",
        type=str,
        default=None,
        help="Path to pipeline config file.",
    )
    parser.add_argument(
        "--fg_output_dir",
        type=str,
        default=None,
        help="Directory to output feature generator json file.",
    )
    parser.add_argument(
        "--reserves",
        type=str,
        default=None,
        help="Reserved column names, e.g. label,request_id.",
    )
    parser.add_argument(
        "--odps_project_name",
        type=str,
        default=None,
        help="odps project name.",
    )
    parser.add_argument(
        "--fg_resource_name",
        type=str,
        default=None,
        help="fg json resource name. "
        "if specified, will upload fg.json to odps.",
    )
    parser.add_argument(
        "--force_update_resource",
        action="store_true",
        default=False,
        help="if true will update fg.json.",
    )
    parser.add_argument(
        "--remove_bucketizer",
        action="store_true",
        default=False,
        help="remove bucktizer params in fg json.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        default=False,
        help="debug feature config and fg json or not.",
    )
    # Unknown extra arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()

    pipeline_config = config_util.load_pipeline_config(args.pipeline_config_path)
    features = _create_features(
        list(pipeline_config.feature_configs), pipeline_config.data_config
    )

    if args.debug:
        # Pull a single batch through the dataloader with one worker so that
        # feature config problems show up before the fg json is written.
        pipeline_config.data_config.num_workers = 1
        dataloader = _get_dataloader(
            pipeline_config.data_config, features, pipeline_config.train_input_path
        )
        next(iter(dataloader))

    tmp_dir = tempfile.mkdtemp(prefix="tzrec_")
    try:
        fg_json = create_fg_json(features, asset_dir=tmp_dir)
        if args.remove_bucketizer:
            fg_json = _strip_bucketizer_params(fg_json)

        if args.reserves is not None:
            fg_json["reserves"] = [
                column.strip() for column in args.reserves.strip().split(",")
            ]

        fg_name = args.fg_resource_name if args.fg_resource_name else "fg.json"
        fg_path = os.path.join(tmp_dir, fg_name)
        with open(fg_path, "w") as f:
            json.dump(fg_json, f, indent=4)

        if args.fg_output_dir:
            shutil.copytree(tmp_dir, args.fg_output_dir, dirs_exist_ok=True)

        project = args.odps_project_name
        fg_resource_name = args.fg_resource_name
        # Upload only when both the project and the resource name are given.
        if project is not None and fg_resource_name is not None:
            account, odps_endpoint = _create_odps_account()
            o = ODPS(
                account=account,
                project=project,
                endpoint=odps_endpoint,
            )
            _upload_resources(o, tmp_dir, args.force_update_resource)
    finally:
        # Always remove the temporary working directory, even when a step
        # above failed.
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)