tzrec/predict.py (91 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

from tzrec.main import predict

if __name__ == "__main__":
    # Command-line entry point: parse the prediction options and forward each
    # one, under the same name, to ``tzrec.main.predict``.

    # Ordered (flag, add_argument keyword arguments) pairs. The order is the
    # registration order, which is also the order shown by ``--help``.
    cli_options = (
        (
            "--scripted_model_path",
            {
                "type": str,
                "default": None,
                "help": "scripted model to be evaled, if not specified, "
                "use the checkpoint",
            },
        ),
        (
            "--predict_input_path",
            {
                "type": str,
                "default": None,
                "help": "inference data input path",
            },
        ),
        (
            "--predict_output_path",
            {
                "type": str,
                "default": None,
                "help": "inference data output path",
            },
        ),
        (
            "--reserved_columns",
            {
                "type": str,
                "default": None,
                "help": "column names to reserved in output",
            },
        ),
        (
            "--output_columns",
            {
                "type": str,
                "default": None,
                "help": "column names of model output",
            },
        ),
        (
            "--batch_size",
            {
                "type": int,
                "default": None,
                "help": "predict batch size, default will use batch_size in config.",
            },
        ),
        (
            "--predict_threads",
            {
                "type": int,
                "default": None,
                "help": "predict threads num, default will use num_workers "
                "in data_config.",
            },
        ),
        (
            "--is_profiling",
            {
                "action": "store_true",
                "default": False,
                "help": "profiling predict progress.",
            },
        ),
        (
            "--debug_level",
            {
                "type": int,
                "default": 0,
                "help": "debug level for debug parsed inputs etc.",
            },
        ),
        (
            "--dataset_type",
            {
                "type": str,
                "default": None,
                "help": "dataset type, default will use dataset type in data_config.",
            },
        ),
        (
            "--writer_type",
            {
                "type": str,
                "default": None,
                "help": "data writer type, default will be same as "
                "dataset_type in data_config.",
            },
        ),
        (
            "--edit_config_json",
            {
                "type": str,
                "default": None,
                "help": "edit pipeline config str, example: "
                '{"data_config.fg_mode":"FG_DAG"}',
            },
        ),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        arg_parser.add_argument(flag, **spec)

    # parse_known_args tolerates unrecognised arguments instead of exiting;
    # the leftovers are not used by this script.
    cli_args, _unparsed = arg_parser.parse_known_args()

    predict(
        predict_input_path=cli_args.predict_input_path,
        predict_output_path=cli_args.predict_output_path,
        scripted_model_path=cli_args.scripted_model_path,
        reserved_columns=cli_args.reserved_columns,
        output_columns=cli_args.output_columns,
        batch_size=cli_args.batch_size,
        is_profiling=cli_args.is_profiling,
        debug_level=cli_args.debug_level,
        dataset_type=cli_args.dataset_type,
        predict_threads=cli_args.predict_threads,
        writer_type=cli_args.writer_type,
        edit_config_json=cli_args.edit_config_json,
    )