easy_rec/python/eval.py (83 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point that evaluates an EasyRec model checkpoint.

The pipeline config is read from ``<model_dir>/pipeline.config`` when
``--model_dir`` is given and that file exists, otherwise from
``--pipeline_config_path``. Horovod / SOK are initialized when the configured
``train_distribute`` strategy needs them, then either local or distributed
evaluation is run and the resulting (non-binary) metrics are logged.
"""
import logging
import os

import six
import tensorflow as tf
from tensorflow.python.lib.io import file_io

from easy_rec.python.main import distribute_evaluate
from easy_rec.python.main import evaluate
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import config_util
from easy_rec.python.utils import ds_util
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils.distribution_utils import set_tf_config_and_get_distribute_eval_worker_num_on_ds  # NOQA

if tf.__version__ >= '2.0':
  # Flags / app.run below are TF1-style APIs; use the v1 compat namespace.
  tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

tf.app.flags.DEFINE_string('pipeline_config_path', None,
                           'Path to pipeline config '
                           'file.')
tf.app.flags.DEFINE_string(
    'checkpoint_path', None, 'checkpoint to be evaled '
    ' if not specified, use the latest checkpoint in '
    'train_config.model_dir')
tf.app.flags.DEFINE_multi_string(
    'eval_input_path', None, 'eval data path, if specified will '
    'override pipeline_config.eval_input_path')
tf.app.flags.DEFINE_string('model_dir', None, help='will update the model_dir')
tf.app.flags.DEFINE_string('odps_config', None, help='odps config path')
tf.app.flags.DEFINE_string('eval_result_path', 'eval_result.txt',
                           'eval result metric file')
tf.app.flags.DEFINE_bool('distribute_eval', False,
                         'use distribute parameter server for train and eval.')
tf.app.flags.DEFINE_bool('is_on_ds', False, help='is on ds')
FLAGS = tf.app.flags.FLAGS


def _resolve_pipeline_config_path():
  """Return the pipeline config path to load for this evaluation.

  ``<model_dir>/pipeline.config`` takes precedence when ``--model_dir`` is set
  and that file exists; otherwise ``--pipeline_config_path`` is used.

  Raises:
    ValueError: if neither flag is set, or if ``--model_dir`` has no
      pipeline.config and ``--pipeline_config_path`` is not set (previously
      this passed ``None`` on to the config loader).
  """
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if not (FLAGS.model_dir or FLAGS.pipeline_config_path):
    raise ValueError('At least one of model_dir and pipeline_config_path '
                     'must be specified.')

  if FLAGS.model_dir:
    candidate = os.path.join(FLAGS.model_dir, 'pipeline.config')
    if file_io.file_exists(candidate):
      logging.info('update pipeline_config_path to %s', candidate)
      return candidate

  if not FLAGS.pipeline_config_path:
    raise ValueError('pipeline.config not found under model_dir=%s and '
                     'pipeline_config_path is not specified.' %
                     FLAGS.model_dir)
  return FLAGS.pipeline_config_path


def _init_distribution(pipeline_config):
  """Initialize Horovod (and SOK) if the train_distribute strategy needs it."""
  strategy = pipeline_config.train_config.train_distribute
  if strategy == DistributionStrategy.HorovodStrategy:
    estimator_utils.init_hvd()
  elif strategy in (DistributionStrategy.EmbeddingParallelStrategy,
                    DistributionStrategy.SokStrategy):
    # SOK is initialized after Horovod, matching the original call order.
    estimator_utils.init_hvd()
    estimator_utils.init_sok()


def _log_eval_result(eval_result):
  """Log every metric in `eval_result`, skipping binary (bytes) values."""
  if eval_result is None:
    # In distributed evaluation only the master worker returns a result.
    logging.info('Eval result in master worker.')
    return
  for key in sorted(eval_result):
    value = eval_result[key]
    if isinstance(value, six.binary_type):
      continue  # skip logging binary data
    logging.info('%s: %s', key, str(value))


def main(argv):
  """Run evaluation as configured by the command-line flags.

  Args:
    argv: remaining command-line arguments passed by `tf.app.run`; unused.

  Raises:
    ValueError: if no usable pipeline config path can be determined.
  """
  if FLAGS.odps_config:
    os.environ['ODPS_CONFIG_FILE_PATH'] = FLAGS.odps_config

  if FLAGS.is_on_ds:
    ds_util.set_on_ds()
    if FLAGS.distribute_eval:
      set_tf_config_and_get_distribute_eval_worker_num_on_ds()

  pipeline_config_path = _resolve_pipeline_config_path()
  pipeline_config = config_util.get_configs_from_pipeline_file(
      pipeline_config_path)
  if FLAGS.model_dir:
    pipeline_config.model_dir = FLAGS.model_dir

  _init_distribution(pipeline_config)

  if FLAGS.distribute_eval:
    os.environ['distribute_eval'] = 'True'
    eval_result = distribute_evaluate(pipeline_config, FLAGS.checkpoint_path,
                                      FLAGS.eval_input_path,
                                      FLAGS.eval_result_path)
  else:
    os.environ['distribute_eval'] = 'False'
    eval_result = evaluate(pipeline_config, FLAGS.checkpoint_path,
                           FLAGS.eval_input_path, FLAGS.eval_result_path)

  _log_eval_result(eval_result)


if __name__ == '__main__':
  tf.app.run()