tzrec/train_eval.py (51 lines of code) (raw):
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from tzrec.main import train_and_evaluate
if __name__ == "__main__":
    # Command-line entry point: gather the options declared below and pass
    # them, positionally, to train_and_evaluate.
    cli = argparse.ArgumentParser()

    # Keyword arguments shared by every plain string-valued option.
    text_option = {"type": str, "default": None}

    # Options are registered in table order, which is also the order in which
    # they show up in the generated --help output.
    option_table = (
        (
            "--pipeline_config_path",
            dict(text_option, help="Path to pipeline config file."),
        ),
        (
            "--model_dir",
            dict(text_option, help="will update the model_dir in pipeline_config"),
        ),
        (
            "--train_input_path",
            dict(text_option, help="train data input path"),
        ),
        (
            "--eval_input_path",
            dict(text_option, help="eval data input path"),
        ),
        (
            "--continue_train",
            {
                "action": "store_true",
                "default": False,
                "help": "continue train using existing model_dir",
            },
        ),
        (
            "--fine_tune_checkpoint",
            dict(
                text_option,
                help="will update the train_config.fine_tune_checkpoint "
                "in pipeline_config",
            ),
        ),
        (
            "--edit_config_json",
            dict(
                text_option,
                help='edit pipeline config str, example: {"model_dir":"experiments/",'
                '"feature_configs[0].raw_feature.boundaries":[4,5,6,7]}',
            ),
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    # parse_known_args accepts unrecognized command-line arguments instead of
    # exiting with an error; the leftovers are collected but not used here.
    options, _leftover = cli.parse_known_args()

    train_and_evaluate(
        options.pipeline_config_path,
        options.train_input_path,
        options.eval_input_path,
        options.model_dir,
        options.continue_train,
        options.fine_tune_checkpoint,
        options.edit_config_json,
    )