easy_rec/python/train_eval.py:
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import json
import logging
import os

import tensorflow as tf
from tensorflow.python.platform import gfile

from easy_rec.python.main import _train_and_evaluate_impl
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import config_util
from easy_rec.python.utils import ds_util
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import hpo_util
from easy_rec.python.utils.config_util import process_neg_sampler_data_path
from easy_rec.python.utils.config_util import set_eval_input_path
from easy_rec.python.utils.config_util import set_train_input_path
from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num_on_ds # NOQA
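
# EasyRec is written against the TF1 graph-mode API; on TF2 installs, alias
# the tf.compat.v1 compatibility layer.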
if tf.__version__ >= '2.0':
  tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--pipeline_config_path',
      type=str,
      default=None,
      help='path to pipeline config file')
  parser.add_argument(
      '--continue_train',
      action='store_true',
      default=False,
      help='continue training using the existing model_dir')
  parser.add_argument(
      '--hpo_param_path',
      type=str,
      default=None,
      help='hyperparameter tuning param path')
  parser.add_argument(
      '--hpo_metric_save_path',
      type=str,
      default=None,
      help='path to save hyperparameter tuning metrics')
  parser.add_argument(
      '--model_dir',
      type=str,
      default=None,
      help='will update the model_dir in pipeline_config')
  parser.add_argument(
      '--train_input_path',
      type=str,
      nargs='*',
      default=None,
      help='train data input path')
  parser.add_argument(
      '--eval_input_path',
      type=str,
      nargs='*',
      default=None,
      help='eval data input path')
  parser.add_argument(
      '--fit_on_eval',
      action='store_true',
      default=False,
      help='fit evaluation data after fitting and evaluating train data')
  parser.add_argument(
      '--fit_on_eval_steps',
      type=int,
      default=None,
      help='number of steps to fit evaluation data')
  parser.add_argument(
      '--fine_tune_checkpoint',
      type=str,
      default=None,
      help='will update train_config.fine_tune_checkpoint in pipeline_config')
  parser.add_argument(
      '--edit_config_json',
      type=str,
      default=None,
      help='edit pipeline config str, example: {"model_dir":"experiments/",'
      '"feature_config.feature[0].boundaries":[4,5,6,7]}')
  parser.add_argument(
      '--ignore_finetune_ckpt_error',
      action='store_true',
      default=False,
      help='during incremental training, ignore missing '
      'fine_tune_checkpoint files')
  parser.add_argument(
      '--odps_config', type=str, default=None, help='odps config path')
  parser.add_argument(
      '--is_on_ds',
      action='store_true',
      default=False,
      help='whether the job runs on a DS cluster')
  parser.add_argument(
      '--check_mode',
      action='store_true',
      default=False,
      help='whether to run in check mode')
  parser.add_argument(
      '--selected_cols', type=str, default=None, help='select input columns')
  parser.add_argument('--gpu', type=str, default=None, help='gpu id')
  args, extra_args = parser.parse_known_args()
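
  # Example invocation (the paths below are illustrative only):
  #   python -m easy_rec.python.train_eval \
  #     --pipeline_config_path samples/model_config/my_model.config \
  #     --model_dir experiments/my_model \
  #     --train_input_path data/train.csv \
  #     --eval_input_path data/eval.csv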
  if args.gpu is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

  edit_config_json = {}
  if args.edit_config_json:
    edit_config_json = json.loads(args.edit_config_json)
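
  # Flags not recognized by argparse (extra_args) are interpreted as
  # additional config edits and merged into edit_config_json.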
  if extra_args is not None and len(extra_args) > 0:
    config_util.parse_extra_config_param(extra_args, edit_config_json)

  if args.pipeline_config_path is not None:
    pipeline_config = config_util.get_configs_from_pipeline_file(
        args.pipeline_config_path, False)
    if args.selected_cols:
      pipeline_config.data_config.selected_cols = args.selected_cols
    if args.model_dir:
      pipeline_config.model_dir = args.model_dir
      logging.info('update model_dir to %s' % pipeline_config.model_dir)
    if args.train_input_path:
      set_train_input_path(pipeline_config, args.train_input_path)
    if args.eval_input_path:
      set_eval_input_path(pipeline_config, args.eval_input_path)
    if args.fine_tune_checkpoint:
      ckpt_path = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
          args.fine_tune_checkpoint, args.ignore_finetune_ckpt_error)
      if ckpt_path:
        pipeline_config.train_config.fine_tune_checkpoint = ckpt_path
    if pipeline_config.fg_json_path:
      fg_util.load_fg_json_to_config(pipeline_config)
    if args.odps_config:
      os.environ['ODPS_CONFIG_FILE_PATH'] = args.odps_config
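
    # Apply --edit_config_json edits; if they set
    # train_config.fine_tune_checkpoint, resolve that path to its latest
    # checkpoint before editing the config.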
    if len(edit_config_json) > 0:
      fine_tune_checkpoint = edit_config_json.get('train_config', {}).get(
          'fine_tune_checkpoint', None)
      if fine_tune_checkpoint:
        # Resolve the checkpoint given in edit_config_json, not the separate
        # --fine_tune_checkpoint flag, which may be unset here.
        ckpt_path = estimator_utils.get_latest_checkpoint_from_checkpoint_path(
            fine_tune_checkpoint, args.ignore_finetune_ckpt_error)
        edit_config_json['train_config']['fine_tune_checkpoint'] = ckpt_path
      config_util.edit_config(pipeline_config, edit_config_json)

    process_neg_sampler_data_path(pipeline_config)

    if args.is_on_ds:
      ds_util.set_on_ds()
      set_tf_config_and_get_train_worker_num_on_ds()
      if pipeline_config.train_config.fine_tune_checkpoint:
        ds_util.cache_ckpt(pipeline_config)
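
    # Initialize the collective-communication backends required by the chosen
    # distribution strategy: Horovod alone, or Horovod plus SOK
    # (SparseOperationKit) for the embedding-parallel strategies.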
    if pipeline_config.train_config.train_distribute in [
        DistributionStrategy.HorovodStrategy,
    ]:
      estimator_utils.init_hvd()
    elif pipeline_config.train_config.train_distribute in [
        DistributionStrategy.EmbeddingParallelStrategy,
        DistributionStrategy.SokStrategy
    ]:
      estimator_utils.init_hvd()
      estimator_utils.init_sok()
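
    # HPO mode: the tuner supplies a JSON file whose 'param' field holds the
    # same dotted key paths accepted by --edit_config_json; after training,
    # eval metrics are saved where the tuner can read them back.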
    if args.hpo_param_path:
      with gfile.GFile(args.hpo_param_path, 'r') as fin:
        hpo_config = json.load(fin)
        hpo_params = hpo_config['param']
      config_util.edit_config(pipeline_config, hpo_params)
      config_util.auto_expand_share_feature_configs(pipeline_config)
      _train_and_evaluate_impl(pipeline_config, args.continue_train,
                               args.check_mode)
      hpo_util.save_eval_metrics(
          pipeline_config.model_dir,
          metric_save_path=args.hpo_metric_save_path,
          has_evaluator=False)
    else:
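      # Standard path: train and evaluate, optionally continuing to fit on the
      # evaluation data afterwards (--fit_on_eval / --fit_on_eval_steps).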
      config_util.auto_expand_share_feature_configs(pipeline_config)
      _train_and_evaluate_impl(
          pipeline_config,
          args.continue_train,
          args.check_mode,
          fit_on_eval=args.fit_on_eval,
          fit_on_eval_steps=args.fit_on_eval_steps)
  else:
    raise ValueError('pipeline_config_path should not be empty when training!')