pai_jobs/run.py (505 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command line entry of EasyRec jobs; the action is chosen by --cmd."""
from __future__ import print_function

import logging
import os
import time

import tensorflow as tf
import yaml
from tensorflow.python.platform import gfile

import easy_rec
from easy_rec.python.inference.odps_predictor import ODPSPredictor
from easy_rec.python.inference.vector_retrieve import VectorRetrieve
from easy_rec.python.tools.pre_check import run_check
from easy_rec.python.utils import config_util
from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import hpo_util
from easy_rec.python.utils import pai_util
from easy_rec.python.utils.distribution_utils import DistributionStrategyMap
from easy_rec.python.utils.distribution_utils import set_distribution_config

# NOTE: the statements below are order sensitive; IS_ON_PAI is written to the
# environment before the next import statement runs.
os.environ['IS_ON_PAI'] = '1'

from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_train_worker_num  # NOQA

# use few threads to avoid oss error
os.environ['OENV_MultiWriteThreadsNum'] = '4'
os.environ['OENV_MultiCopyThreadsNum'] = '4'

# For every tensorflow release except 1.12.x, switch to the v1 compatibility
# api and turn eager execution off.
if not tf.__version__.startswith('1.12'):
  tf = tf.compat.v1
  try:
    import tensorflow_io as tfio  # noqa: F401
  except Exception as ex:
    # tensorflow_io is optional: report the failure and keep going
    logging.error('failed to import tfio: %s' % str(ex))
  tf.disable_eager_execution()

from easy_rec.python.main import _train_and_evaluate_impl as train_and_evaluate_impl  # NOQA

logging.basicConfig(
    level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

# cluster description and command selection
tf.app.flags.DEFINE_string('worker_hosts', '',
                           'Comma-separated list of hostname:port pairs')
tf.app.flags.DEFINE_string('ps_hosts', '',
                           'Comma-separated list of hostname:port pairs')
tf.app.flags.DEFINE_string('job_name', '', 'task type, ps/worker')
tf.app.flags.DEFINE_integer('task_index', 0, 'Index of task within the job')
tf.app.flags.DEFINE_string('config', '', 'EasyRec config file path')
tf.app.flags.DEFINE_string('cmd', 'train',
                           'command type, train/evaluate/export')
tf.app.flags.DEFINE_string('tables', '', 'tables passed by pai command')

# flags for train
tf.app.flags.DEFINE_integer('num_gpus_per_worker', 1,
                            'number of gpu to use in training')
tf.app.flags.DEFINE_boolean('with_evaluator', False,
                            'whether a evaluator is necessary')
tf.app.flags.DEFINE_string(
    'eval_method', 'none', 'default to none, choices are [none: not evaluate,' +
    'master: evaluate on master, separate: evaluate on a separate task]')
tf.app.flags.DEFINE_string('distribute_strategy', '',
                           'training distribute strategy')
tf.app.flags.DEFINE_string('edit_config_json', '', 'edit config json string')
tf.app.flags.DEFINE_string('train_tables', '', 'tables used for train')
tf.app.flags.DEFINE_string('eval_tables', '', 'tables used for evaluation')
tf.app.flags.DEFINE_string('boundary_table', '', 'tables used for boundary')
tf.app.flags.DEFINE_string('sampler_table', '', 'tables used for sampler')
tf.app.flags.DEFINE_string('fine_tune_checkpoint', None,
                           'finetune checkpoint path')
tf.app.flags.DEFINE_string('query_table', '',
                           'table used for retrieve vector neighbours')
tf.app.flags.DEFINE_string('doc_table', '',
                           'table used for be retrieved as indexed vectors')
tf.app.flags.DEFINE_enum('knn_distance', 'inner_product',
                         ['l2', 'inner_product'], 'type of knn distance')
tf.app.flags.DEFINE_integer('knn_num_neighbours', None,
                            'top n neighbours to be retrieved')
tf.app.flags.DEFINE_integer('knn_feature_dims', None,
                            'number of feature dimensions')
# 'gpu_ivfpg' is a historical misspelling of 'gpu_ivfpq' (the spelling used by
# the knn_compress_dim help text); both are accepted so that existing commands
# keep working.
tf.app.flags.DEFINE_enum('knn_index_type', 'ivfflat', [
    'flat', 'ivfflat', 'ivfpq', 'gpu_flat', 'gpu_ivfflat', 'gpu_ivfpg',
    'gpu_ivfpq'
], 'knn index type')
tf.app.flags.DEFINE_string('knn_feature_delimiter', ',',
                           'delimiter for feature vectors')
tf.app.flags.DEFINE_integer('knn_nlist', 5,
                            'number of split part on each worker')
tf.app.flags.DEFINE_integer('knn_nprobe', 2,
                            'number of probe part on each worker')
tf.app.flags.DEFINE_integer(
    'knn_compress_dim', 8,
    'number of dimensions after compress for `ivfpq` and `gpu_ivfpq`')

# flags used for evaluate & export
tf.app.flags.DEFINE_string(
    'checkpoint_path', '', 'checkpoint to be evaluated or exported '
    'if not specified, use the latest checkpoint '
    'in train_config.model_dir')

# flags used for evaluate
tf.app.flags.DEFINE_string('eval_result_path', 'eval_result.txt',
                           'eval result metric file')
tf.app.flags.DEFINE_bool('distribute_eval', False,
                         'use distribute parameter server for train and eval.')

# flags used for export
tf.app.flags.DEFINE_string('export_dir', '',
                           'directory where model should be exported to')
tf.app.flags.DEFINE_bool('clear_export', False, 'remove export_dir if exists')
tf.app.flags.DEFINE_string('export_done_file', '',
                           'a flag file to signal that export model is done')
tf.app.flags.DEFINE_integer('max_wait_ckpt_ts', 0,
                            'max wait time in seconds for checkpoints')
tf.app.flags.DEFINE_boolean('continue_train', True,
                            'use the same model to continue train or not')

# flags used for predict
tf.app.flags.DEFINE_string('saved_model_dir', '',
                           'directory where saved_model.pb exists')
tf.app.flags.DEFINE_string('outputs', '', 'output tables')
tf.app.flags.DEFINE_string(
    'all_cols', '',
    'union of (selected_cols, reserved_cols), separated with , ')
tf.app.flags.DEFINE_string(
    'all_col_types', '',
    'column data types, for build record defaults, separated with ,')
tf.app.flags.DEFINE_string(
    'selected_cols', '',
    'columns to keep from input table, they are separated with ,')
tf.app.flags.DEFINE_string(
    'reserved_cols', '',
    'columns to keep from input table, they are separated with ,')
tf.app.flags.DEFINE_string(
    'output_cols', None,
    'output columns, such as: score float. multiple columns are separated by ,')
tf.app.flags.DEFINE_integer('batch_size', 1024, 'predict batch size')
tf.app.flags.DEFINE_string(
    'profiling_file', None,
    'time stat file which can be viewed using chrome tracing')
tf.app.flags.DEFINE_string('redis_url', None, 'export to redis url, host:port')
tf.app.flags.DEFINE_string('redis_passwd', None, 'export to redis passwd')
tf.app.flags.DEFINE_integer('redis_threads', 5, 'export to redis threads')
tf.app.flags.DEFINE_integer('redis_batch_size', 1024,
                            'export to redis batch_size')
tf.app.flags.DEFINE_integer('redis_timeout', 600,
                            'export to redis time_out in seconds')
tf.app.flags.DEFINE_integer('redis_expire', 24,
                            'export to redis expire time in hour')
tf.app.flags.DEFINE_string('redis_embedding_version', '',
                           'redis embedding version')
tf.app.flags.DEFINE_integer('redis_write_kv', 1, 'whether write kv ')

tf.app.flags.DEFINE_string(
    'oss_path', None, 'write embed objects to oss folder, oss://bucket/folder')
tf.app.flags.DEFINE_string('oss_endpoint', None, 'oss endpoint')
tf.app.flags.DEFINE_string('oss_ak', None, 'oss ak')
tf.app.flags.DEFINE_string('oss_sk', None, 'oss sk')
tf.app.flags.DEFINE_integer('oss_threads', 10,
                            '# threads access oss at the same time')
tf.app.flags.DEFINE_integer('oss_timeout', 10,
                            'connect to oss, time_out in seconds')
tf.app.flags.DEFINE_integer('oss_expire', 24, 'oss expire time in hours')
tf.app.flags.DEFINE_integer('oss_write_kv', 1,
                            'whether to write embedding to oss')
tf.app.flags.DEFINE_string('oss_embedding_version', '',
                           'oss embedding version')

tf.app.flags.DEFINE_bool('verbose', False, 'print more debug information')
tf.app.flags.DEFINE_bool('place_embedding_on_cpu', False,
                         'whether to place embedding variables on cpu')

# for automl hyper parameter tuning
tf.app.flags.DEFINE_string('model_dir', None, 'model directory')
tf.app.flags.DEFINE_bool('clear_model', False,
                         'remove model directory if exists')
tf.app.flags.DEFINE_string('hpo_param_path', None,
                           'hyperparameter tuning param path')
tf.app.flags.DEFINE_string('hpo_metric_save_path', None,
                           'hyperparameter save metric path')
tf.app.flags.DEFINE_string('asset_files', None, 'extra files to add to export')
tf.app.flags.DEFINE_bool('check_mode', False, 'is use check mode')
tf.app.flags.DEFINE_string('fg_json_path', None, '')
tf.app.flags.DEFINE_bool('enable_avx_str_split', False,
                         'enable avx str split to speedup')

FLAGS = tf.app.flags.FLAGS

# seconds between two checkpoint probes in _wait_ckpt
_CKPT_POLL_INTERVAL = 30


def check_param(name):
  """Assert that flag `name` holds a non-empty value.

  Args:
    name: name of a flag registered on FLAGS.

  Raises:
    AssertionError: if the flag value is '' or None.
  """
  assert getattr(FLAGS, name) not in ('', None), \
      '%s should not be empty' % name


def _require_pipeline_config(pipeline_config, flag_name):
  """Assert that a pipeline config was loaded before `flag_name` is applied.

  Without --config no pipeline config exists; this turns what used to be a
  NameError into a readable assertion message.
  """
  assert pipeline_config is not None, (
      'config should not be empty when %s is specified' % flag_name)


def set_selected_cols(pipeline_config, selected_cols, all_cols, all_col_types):
  """Write selected_cols and their types into pipeline_config.data_config.

  Nothing is written when selected_cols is empty; selected_col_types is only
  written when all_cols is given as well. The resulting values are printed in
  both cases.

  Args:
    pipeline_config: pipeline config whose data_config is updated in place.
    selected_cols: comma separated names of the columns fed to the model.
    all_cols: comma separated names of all the columns.
    all_col_types: comma separated column types, paired with all_cols by
      position.

  Raises:
    KeyError: if a selected column is not listed in all_cols.
  """
  if selected_cols:
    pipeline_config.data_config.selected_cols = selected_cols
    # add column types which will be used by OdpsInput, OdpsInputV2
    # to check consistency with input_fields.input_type
    if all_cols:
      col_type_map = {
          col.strip(): col_type.strip() for col, col_type in zip(
              all_cols.split(','), all_col_types.split(','))
      }
      selected_arr = [col.strip() for col in selected_cols.split(',')]
      missing = [col for col in selected_arr if col not in col_type_map]
      if missing:
        # same exception type as the plain dict lookup, but names the columns
        raise KeyError('selected_cols not found in all_cols: %s' %
                       ','.join(missing))
      pipeline_config.data_config.selected_col_types = ','.join(
          col_type_map[col] for col in selected_arr)

  print('[run.py] data_config.selected_cols = "%s"' %
        pipeline_config.data_config.selected_cols)
  print('[run.py] data_config.selected_col_types = "%s"' %
        pipeline_config.data_config.selected_col_types)


def _wait_ckpt(ckpt_path, max_wait_ts):
  """Block until a checkpoint shows up or max_wait_ts seconds have elapsed.

  Args:
    ckpt_path: a path containing '/model.ckpt-' is treated as one specific
      checkpoint and its '.index' file is probed; any other path is treated as
      a directory and probed with estimator_utils.latest_checkpoint.
    max_wait_ts: maximal number of seconds to wait.

  Returns:
    None. A warning is logged when the wait times out.
  """
  logging.info('will wait %s seconds for checkpoint' % max_wait_ts)
  wait_for_dir = '/model.ckpt-' not in ckpt_path
  deadline = time.time() + max_wait_ts
  while True:
    if wait_for_dir:
      found_ckpt = estimator_utils.latest_checkpoint(ckpt_path)
      if found_ckpt is not None:
        logging.info('find checkpoint[%s] in directory[%s]' %
                     (found_ckpt, ckpt_path))
        return
      logging.info('wait for checkpoint in directory[%s]' % ckpt_path)
    else:
      if gfile.Exists(ckpt_path + '.index'):
        logging.info('find checkpoint[%s]' % ckpt_path)
        return
      logging.info('wait for checkpoint[%s]' % ckpt_path)

    remaining = deadline - time.time()
    if remaining <= 0:
      break
    # never sleep past the deadline, so that the last probe happens right
    # when the wait budget is used up
    time.sleep(min(_CKPT_POLL_INTERVAL, remaining))
  logging.warning('checkpoint[%s] is still missing after waiting %s seconds' %
                  (ckpt_path, max_wait_ts))


def _load_pipeline_config(num_worker):
  """Load the pipeline config and apply the command line overrides.

  Args:
    num_worker: number of entries in --worker_hosts.

  Returns:
    the pipeline config, or None when --config is empty.
  """
  pipeline_config = None
  if FLAGS.config:
    config = pai_util.process_config(FLAGS.config, FLAGS.task_index,
                                     num_worker)
    pipeline_config = config_util.get_configs_from_pipeline_file(config, False)

    # should be in front of edit_config_json step
    # otherwise data_config and feature_config are not ready
    if pipeline_config.fg_json_path:
      fg_util.load_fg_json_to_config(pipeline_config)

  if FLAGS.edit_config_json:
    _require_pipeline_config(pipeline_config, 'edit_config_json')
    print('[run.py] edit_config_json = %s' % FLAGS.edit_config_json)
    config_json = yaml.safe_load(FLAGS.edit_config_json)
    config_util.edit_config(pipeline_config, config_json)

  if FLAGS.model_dir:
    _require_pipeline_config(pipeline_config, 'model_dir')
    pipeline_config.model_dir = FLAGS.model_dir.strip()
    print('[run.py] update model_dir to %s' % pipeline_config.model_dir)
    assert pipeline_config.model_dir.startswith(
        'oss://'), 'invalid model_dir format: %s' % pipeline_config.model_dir

  if FLAGS.asset_files:
    _require_pipeline_config(pipeline_config, 'asset_files')
    pipeline_config.export_config.asset_files.extend(
        FLAGS.asset_files.split(','))

  if pipeline_config is not None:
    if not pipeline_config.model_dir.endswith('/'):
      pipeline_config.model_dir += '/'

  if FLAGS.clear_model:
    _require_pipeline_config(pipeline_config, 'clear_model')
    if gfile.IsDirectory(
        pipeline_config.model_dir) and estimator_utils.is_chief():
      gfile.DeleteRecursively(pipeline_config.model_dir)

  if FLAGS.max_wait_ckpt_ts > 0:
    if FLAGS.checkpoint_path:
      _wait_ckpt(FLAGS.checkpoint_path, FLAGS.max_wait_ckpt_ts)
    else:
      _require_pipeline_config(pipeline_config, 'max_wait_ckpt_ts')
      _wait_ckpt(pipeline_config.model_dir, FLAGS.max_wait_ckpt_ts)
  return pipeline_config


def _set_selected_cols_or_clear(pipeline_config, has_tables):
  """Set selected cols from the flags, or blank them when no table is given."""
  if has_tables:
    # parse selected_cols
    set_selected_cols(pipeline_config, FLAGS.selected_cols, FLAGS.all_cols,
                      FLAGS.all_col_types)
  else:
    pipeline_config.data_config.selected_cols = ''
    pipeline_config.data_config.selected_col_types = ''


def _collect_export_params():
  """Gather the redis_* and oss_* flags forwarded to easy_rec.export.

  Returns:
    dict mapping flag name to value; string flags are included when non-empty,
    integer flags when positive and the *_write_kv flags when non-zero.
  """
  params = {}
  for name in ('redis_url', 'redis_passwd', 'redis_embedding_version',
               'oss_path', 'oss_endpoint', 'oss_ak', 'oss_sk',
               'oss_embedding_version'):
    value = getattr(FLAGS, name)
    if value:
      params[name] = value
  # NOTE(review): redis_timeout is defined as a flag but has never been
  # forwarded by this script; that is kept as is -- confirm whether export
  # expects it.
  for name in ('redis_threads', 'redis_batch_size', 'redis_expire',
               'oss_timeout', 'oss_expire', 'oss_threads'):
    value = getattr(FLAGS, name)
    if value > 0:
      params[name] = value
  if FLAGS.redis_write_kv:
    params['redis_write_kv'] = FLAGS.redis_write_kv
  if FLAGS.oss_write_kv:
    params['oss_write_kv'] = FLAGS.oss_write_kv == 1
  return params


def _run_train(pipeline_config, num_gpus_per_worker):
  """Run cmd=train: resolve the input tables, then train and evaluate."""
  assert FLAGS.config, 'config should not be empty when training!'

  tables = FLAGS.tables.split(',') if FLAGS.tables else []
  # --tables must provide both the train and the eval table unless both of
  # them are given explicitly through --train_tables and --eval_tables
  if tables and not (FLAGS.train_tables and FLAGS.eval_tables):
    assert len(
        tables
    ) >= 2, 'at least 2 tables must be specified, but only[%d]: %s' % (
        len(tables), FLAGS.tables)

  if FLAGS.train_tables:
    pipeline_config.train_input_path = FLAGS.train_tables
  elif tables:
    pipeline_config.train_input_path = tables[0]

  if FLAGS.eval_tables:
    pipeline_config.eval_input_path = FLAGS.eval_tables
  elif tables:
    pipeline_config.eval_input_path = tables[1]

  print('[run.py] train_tables: %s' % pipeline_config.train_input_path)
  print('[run.py] eval_tables: %s' % pipeline_config.eval_input_path)

  train_config = pipeline_config.train_config
  if FLAGS.fine_tune_checkpoint:
    train_config.fine_tune_checkpoint = FLAGS.fine_tune_checkpoint

  if train_config.HasField('fine_tune_checkpoint'):
    train_config.fine_tune_checkpoint = \
        estimator_utils.get_latest_checkpoint_from_checkpoint_path(
            train_config.fine_tune_checkpoint, False)

  if FLAGS.boundary_table:
    logging.info('Load boundary_table: %s' % FLAGS.boundary_table)
    config_util.add_boundaries_to_config(pipeline_config, FLAGS.boundary_table)

  if FLAGS.sampler_table:
    pipeline_config.data_config.negative_sampler.input_path = \
        FLAGS.sampler_table

  _set_selected_cols_or_clear(pipeline_config,
                              bool(FLAGS.train_tables or FLAGS.tables))

  distribute_strategy = DistributionStrategyMap[FLAGS.distribute_strategy]

  # update params specified by automl if hpo_param_path is specified
  if FLAGS.hpo_param_path:
    logging.info('hpo_param_path = %s' % FLAGS.hpo_param_path)
    with gfile.GFile(FLAGS.hpo_param_path, 'r') as fin:
      hpo_config = yaml.safe_load(fin)
    config_util.edit_config(pipeline_config, hpo_config['param'])
  config_util.auto_expand_share_feature_configs(pipeline_config)

  print('[run.py] with_evaluator %s' % str(FLAGS.with_evaluator))
  print('[run.py] eval_method %s' % FLAGS.eval_method)
  assert FLAGS.eval_method in [
      'none', 'master', 'separate'
  ], 'invalid evalaute_method: %s' % FLAGS.eval_method

  # with_evaluator is deprecated, kept for compatibility
  if FLAGS.with_evaluator:
    FLAGS.eval_method = 'separate'

  num_worker = set_tf_config_and_get_train_worker_num(
      FLAGS.ps_hosts,
      FLAGS.worker_hosts,
      FLAGS.task_index,
      FLAGS.job_name,
      distribute_strategy=distribute_strategy,
      eval_method=FLAGS.eval_method)
  set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
                          distribute_strategy)
  logging.info('run.py check_mode: %s .' % FLAGS.check_mode)
  train_and_evaluate_impl(
      pipeline_config,
      continue_train=FLAGS.continue_train,
      check_mode=FLAGS.check_mode)

  if FLAGS.hpo_metric_save_path:
    hpo_util.save_eval_metrics(
        pipeline_config.model_dir,
        metric_save_path=FLAGS.hpo_metric_save_path,
        has_evaluator=(FLAGS.eval_method == 'separate'))


def _run_evaluate(pipeline_config, num_worker, num_gpus_per_worker):
  """Run cmd=evaluate, distributed or not depending on --distribute_eval."""
  check_param('config')
  # TODO: support multi-worker evaluation
  if not FLAGS.distribute_eval:
    assert num_worker == 1, 'evaluate only need 1 worker'
  config_util.auto_expand_share_feature_configs(pipeline_config)

  if FLAGS.eval_tables:
    pipeline_config.eval_input_path = FLAGS.eval_tables
  elif FLAGS.tables:
    pipeline_config.eval_input_path = FLAGS.tables.split(',')[0]

  distribute_strategy = DistributionStrategyMap[FLAGS.distribute_strategy]
  set_tf_config_and_get_train_worker_num(
      FLAGS.ps_hosts,
      FLAGS.worker_hosts,
      FLAGS.task_index,
      FLAGS.job_name,
      eval_method='none')
  set_distribution_config(pipeline_config, num_worker, num_gpus_per_worker,
                          distribute_strategy)

  _set_selected_cols_or_clear(pipeline_config,
                              bool(FLAGS.eval_tables or FLAGS.tables))

  if FLAGS.distribute_eval:
    os.environ['distribute_eval'] = 'True'
    logging.info('will_use_distribute_eval')
    evaluate_fn = easy_rec.distribute_evaluate
  else:
    os.environ['distribute_eval'] = 'False'
    logging.info('will_use_eval')
    evaluate_fn = easy_rec.evaluate
  logging.info('distribute_eval = {}'.format(
      os.environ.get('distribute_eval')))
  evaluate_fn(pipeline_config, FLAGS.checkpoint_path, None,
              FLAGS.eval_result_path)


def _run_export(pipeline_config, num_worker):
  """Run cmd=export and optionally create the export done file."""
  check_param('export_dir')
  check_param('config')
  os.environ['place_embedding_on_cpu'] = (
      'True' if FLAGS.place_embedding_on_cpu else 'False')
  extra_params = _collect_export_params()

  set_tf_config_and_get_train_worker_num(
      FLAGS.ps_hosts,
      FLAGS.worker_hosts,
      FLAGS.task_index,
      FLAGS.job_name,
      eval_method='none')

  assert num_worker == 1, 'export only need 1 woker'
  config_util.auto_expand_share_feature_configs(pipeline_config)

  export_dir = FLAGS.export_dir
  if not export_dir.endswith('/'):
    export_dir = export_dir + '/'
  if FLAGS.clear_export and gfile.IsDirectory(export_dir):
    gfile.DeleteRecursively(export_dir)

  export_out_dir = easy_rec.export(export_dir, pipeline_config,
                                   FLAGS.checkpoint_path, FLAGS.asset_files,
                                   FLAGS.verbose, **extra_params)
  if FLAGS.export_done_file:
    flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
    logging.info('create export done file: %s' % flag_file)
    with gfile.GFile(flag_file, 'w') as fout:
      fout.write('ExportDone')


def _run_predict(num_worker):
  """Run cmd=predict with ODPSPredictor; each worker handles one slice."""
  check_param('tables')
  check_param('saved_model_dir')
  logging.info('will use the following columns as model input: %s' %
               FLAGS.selected_cols)
  logging.info('will copy the following columns to output: %s' %
               FLAGS.reserved_cols)

  # only the task with index 0 gets the profiling file
  profiling_file = FLAGS.profiling_file if FLAGS.task_index == 0 else None
  if profiling_file is not None:
    print('profiling_file = %s ' % profiling_file)
  predictor = ODPSPredictor(
      FLAGS.saved_model_dir,
      fg_json_path=FLAGS.fg_json_path,
      profiling_file=profiling_file,
      all_cols=FLAGS.all_cols,
      all_col_types=FLAGS.all_col_types)
  input_table, output_table = FLAGS.tables, FLAGS.outputs
  logging.info('input_table = %s, output_table = %s' %
               (input_table, output_table))
  predictor.predict_impl(
      input_table,
      output_table,
      reserved_cols=FLAGS.reserved_cols,
      output_cols=FLAGS.output_cols,
      batch_size=FLAGS.batch_size,
      slice_id=FLAGS.task_index,
      slice_num=num_worker)


def _run_export_checkpoint(pipeline_config, num_worker):
  """Run cmd=export_checkpoint, writing to '<export_dir>/model'."""
  check_param('export_dir')
  check_param('config')
  set_tf_config_and_get_train_worker_num(
      FLAGS.ps_hosts,
      FLAGS.worker_hosts,
      FLAGS.task_index,
      FLAGS.job_name,
      eval_method='none')
  assert num_worker == 1, 'export only need 1 woker'
  config_util.auto_expand_share_feature_configs(pipeline_config)
  easy_rec.export_checkpoint(
      pipeline_config,
      export_path=FLAGS.export_dir + '/model',
      checkpoint_path=FLAGS.checkpoint_path,
      asset_files=FLAGS.asset_files,
      verbose=FLAGS.verbose)


def _run_vector_retrieve(num_worker):
  """Run cmd=vector_retrieve with VectorRetrieve."""
  check_param('knn_distance')
  assert FLAGS.knn_feature_dims is not None, \
      '`knn_feature_dims` should not be None'
  assert FLAGS.knn_num_neighbours is not None, \
      '`knn_num_neighbours` should not be None'

  query_table, doc_table, output_table = (FLAGS.query_table, FLAGS.doc_table,
                                          FLAGS.outputs)
  if not query_table:
    # drop empty entries: ''.split(',') is [''], which would let the
    # assertion below pass without any table
    tables = [table for table in FLAGS.tables.split(',') if table]
    assert len(
        tables
    ) >= 1, 'at least 1 tables must be specified, but only[%d]: %s' % (
        len(tables), FLAGS.tables)
    query_table = tables[0]
    # a single table is used as both the query and the doc table
    doc_table = tables[1] if len(tables) > 1 else query_table

  knn = VectorRetrieve(
      query_table,
      doc_table,
      output_table,
      ndim=FLAGS.knn_feature_dims,
      distance=1 if FLAGS.knn_distance == 'inner_product' else 0,
      delimiter=FLAGS.knn_feature_delimiter,
      batch_size=FLAGS.batch_size,
      index_type=FLAGS.knn_index_type,
      nlist=FLAGS.knn_nlist,
      nprobe=FLAGS.knn_nprobe,
      m=FLAGS.knn_compress_dim)
  knn(FLAGS.knn_num_neighbours, FLAGS.task_index, num_worker)


def main(argv):
  """Entry point invoked by tf.app.run(): run the command named by --cmd.

  Args:
    argv: command line arguments left over after flag parsing; not used.

  Raises:
    ValueError: if --cmd is not one of the supported commands.
    AssertionError: if a flag required by the chosen command is missing or
      invalid.
  """
  pai_util.set_on_pai()
  if FLAGS.enable_avx_str_split:
    constant.enable_avx_str_split()
    logging.info('will enable avx str split: %s' %
                 constant.is_avx_str_split_enabled())

  if FLAGS.distribute_eval:
    os.environ['distribute_eval'] = 'True'

  # load lookup op; a failure is reported but does not stop the job
  try:
    lookup_op_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
    tf.load_op_library(lookup_op_path)
  except Exception as ex:
    print('Error: exception: %s' % str(ex))

  num_gpus_per_worker = FLAGS.num_gpus_per_worker
  num_worker = len(FLAGS.worker_hosts.split(','))

  assert FLAGS.distribute_strategy in DistributionStrategyMap, \
      'invalid distribute_strategy [%s], available ones are %s' % (
          FLAGS.distribute_strategy, ','.join(DistributionStrategyMap.keys()))

  pipeline_config = _load_pipeline_config(num_worker)

  if FLAGS.cmd == 'train':
    _run_train(pipeline_config, num_gpus_per_worker)
  elif FLAGS.cmd == 'evaluate':
    _run_evaluate(pipeline_config, num_worker, num_gpus_per_worker)
  elif FLAGS.cmd == 'export':
    _run_export(pipeline_config, num_worker)
  elif FLAGS.cmd == 'predict':
    _run_predict(num_worker)
  elif FLAGS.cmd == 'export_checkpoint':
    _run_export_checkpoint(pipeline_config, num_worker)
  elif FLAGS.cmd == 'vector_retrieve':
    _run_vector_retrieve(num_worker)
  elif FLAGS.cmd == 'check':
    # run_check needs the pipeline config, hence --config is mandatory
    check_param('config')
    run_check(pipeline_config, FLAGS.tables)
  else:
    raise ValueError(
        'cmd should be one of train/evaluate/export/predict/export_checkpoint/vector_retrieve/check'
    )


if __name__ == '__main__':
  tf.app.run()