easy_rec/python/export.py (125 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Command-line entry point that exports a trained model.

The script reads its options from ``tf.app.flags``, resolves which pipeline
config file to use, optionally initialises the distributed runtime required
by the configured training strategy, and then calls
``easy_rec.python.main.export``.  Redis / OSS related flags are forwarded to
``export`` as keyword arguments.
"""
import logging
import os

import tensorflow as tf
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import gfile

from easy_rec.python.main import export
from easy_rec.python.protos.train_pb2 import DistributionStrategy
from easy_rec.python.utils import config_util
from easy_rec.python.utils import estimator_utils

# NOTE(review): this is a lexicographic string comparison, so it only behaves
# as a version check while the major version is a single digit.
if tf.__version__ >= '2.0':
    tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

# ---- generic export flags ---------------------------------------------------
tf.app.flags.DEFINE_string('pipeline_config_path', None,
                           'Path to pipeline config '
                           'file.')
tf.app.flags.DEFINE_string('checkpoint_path', '', 'checkpoint to be exported')
tf.app.flags.DEFINE_string('export_dir', None,
                           'directory where model should be exported to')

# ---- redis flags ------------------------------------------------------------
tf.app.flags.DEFINE_string('redis_url', None, 'export to redis url, host:port')
tf.app.flags.DEFINE_string('redis_passwd', None, 'export to redis passwd')
tf.app.flags.DEFINE_integer('redis_threads', 0, 'export to redis threads')
tf.app.flags.DEFINE_integer('redis_batch_size', 256,
                            'export to redis batch_size')
# NOTE(review): redis_timeout is defined here but nothing in this file reads
# it or forwards it to export(); confirm whether that is intentional.
tf.app.flags.DEFINE_integer('redis_timeout', 600,
                            'export to redis time_out in seconds')
tf.app.flags.DEFINE_integer('redis_expire', 24,
                            'export to redis expire time in hour')
tf.app.flags.DEFINE_string('redis_embedding_version', '',
                           'redis embedding version')
tf.app.flags.DEFINE_integer('redis_write_kv', 1,
                            'whether to write embedding to redis')

# ---- oss flags --------------------------------------------------------------
tf.app.flags.DEFINE_string(
    'oss_path', None, 'write embed objects to oss folder, oss://bucket/folder')
tf.app.flags.DEFINE_string('oss_endpoint', None, 'oss endpoint')
tf.app.flags.DEFINE_string('oss_ak', None, 'oss ak')
tf.app.flags.DEFINE_string('oss_sk', None, 'oss sk')
tf.app.flags.DEFINE_integer('oss_threads', 10,
                            '# threads access oss at the same time')
tf.app.flags.DEFINE_integer('oss_timeout', 10,
                            'connect to oss, time_out in seconds')
tf.app.flags.DEFINE_integer('oss_expire', 24, 'oss expire time in hours')
tf.app.flags.DEFINE_integer('oss_write_kv', 1,
                            'whether to write embedding to oss')
tf.app.flags.DEFINE_string('oss_embedding_version', '',
                           'oss embedding version')

# ---- misc flags -------------------------------------------------------------
tf.app.flags.DEFINE_string('asset_files', '', 'more files to add to asset')
tf.app.flags.DEFINE_bool('verbose', False, 'print more debug information')

tf.app.flags.DEFINE_string('model_dir', None, help='will update the model_dir')
tf.app.flags.mark_flag_as_required('export_dir')

tf.app.flags.DEFINE_bool('clear_export', False, 'remove export_dir if exists')
tf.app.flags.DEFINE_string('export_done_file', '',
                           'a flag file to signal that export model is done')
FLAGS = tf.app.flags.FLAGS


def _redis_params():
    """Return the redis-related keyword arguments to forward to ``export``.

    String flags are forwarded only when non-empty and integer flags only
    when positive.  ``redis_write_kv`` is always present, as a bool.
    """
    params = {}
    if FLAGS.redis_url:
        params['redis_url'] = FLAGS.redis_url
    if FLAGS.redis_passwd:
        params['redis_passwd'] = FLAGS.redis_passwd
    if FLAGS.redis_threads > 0:
        params['redis_threads'] = FLAGS.redis_threads
    if FLAGS.redis_batch_size > 0:
        params['redis_batch_size'] = FLAGS.redis_batch_size
    if FLAGS.redis_expire > 0:
        params['redis_expire'] = FLAGS.redis_expire
    if FLAGS.redis_embedding_version:
        params['redis_embedding_version'] = FLAGS.redis_embedding_version
    # Any non-zero value enables writing.
    params['redis_write_kv'] = FLAGS.redis_write_kv != 0
    return params


def _oss_params():
    """Return the oss-related keyword arguments to forward to ``export``.

    String flags are forwarded only when non-empty and integer flags only
    when positive.
    """
    params = {}
    if FLAGS.oss_path:
        params['oss_path'] = FLAGS.oss_path
    if FLAGS.oss_endpoint:
        params['oss_endpoint'] = FLAGS.oss_endpoint
    if FLAGS.oss_ak:
        params['oss_ak'] = FLAGS.oss_ak
    if FLAGS.oss_sk:
        params['oss_sk'] = FLAGS.oss_sk
    if FLAGS.oss_timeout > 0:
        params['oss_timeout'] = FLAGS.oss_timeout
    if FLAGS.oss_expire > 0:
        params['oss_expire'] = FLAGS.oss_expire
    if FLAGS.oss_threads > 0:
        params['oss_threads'] = FLAGS.oss_threads
    # NOTE(review): unlike redis_write_kv, the key is omitted entirely when
    # the flag is 0, and any non-zero value other than 1 yields False.  This
    # mirrors the pre-existing behaviour; confirm it is what export() expects.
    if FLAGS.oss_write_kv:
        params['oss_write_kv'] = FLAGS.oss_write_kv == 1
    if FLAGS.oss_embedding_version:
        params['oss_embedding_version'] = FLAGS.oss_embedding_version
    return params


def _resolve_pipeline_config_path():
    """Pick the pipeline config file to export from.

    ``<model_dir>/pipeline.config`` wins when ``--model_dir`` is given and
    that file exists; otherwise ``--pipeline_config_path`` is used.

    Returns:
      The path of the pipeline config file.

    Raises:
      ValueError: if neither flag is set, or if ``--model_dir`` holds no
        ``pipeline.config`` and ``--pipeline_config_path`` is not set.
    """
    # Explicit raise instead of ``assert``: asserts are removed under
    # ``python -O`` and the check would silently vanish.
    if not (FLAGS.model_dir or FLAGS.pipeline_config_path):
        raise ValueError(
            'At least one of model_dir and pipeline_config_path exists.')

    if FLAGS.model_dir:
        candidate = os.path.join(FLAGS.model_dir, 'pipeline.config')
        if file_io.file_exists(candidate):
            logging.info('update pipeline_config_path to %s', candidate)
            return candidate

    if not FLAGS.pipeline_config_path:
        # Reached only when model_dir was given without a pipeline.config in
        # it; fail here rather than handing None to the config loader.
        raise ValueError(
            'pipeline.config not found under model_dir=%s and '
            'pipeline_config_path is not set.' % FLAGS.model_dir)
    return FLAGS.pipeline_config_path


def _init_distribution(pipeline_config):
    """Initialise the runtime required by the configured train strategy."""
    strategy = pipeline_config.train_config.train_distribute
    if strategy == DistributionStrategy.HorovodStrategy:
        estimator_utils.init_hvd()
    elif strategy in (DistributionStrategy.EmbeddingParallelStrategy,
                      DistributionStrategy.SokStrategy):
        estimator_utils.init_hvd()
        estimator_utils.init_sok()


def main(argv):
    """Export the model described by the command-line flags.

    Args:
      argv: positional command-line arguments handed over by ``tf.app.run``;
        not used.

    Raises:
      ValueError: if no usable pipeline config path can be determined.
    """
    del argv  # Unused.

    extra_params = _redis_params()
    pipeline_config_path = _resolve_pipeline_config_path()
    extra_params.update(_oss_params())

    pipeline_config = config_util.get_configs_from_pipeline_file(
        pipeline_config_path)
    _init_distribution(pipeline_config)

    if FLAGS.clear_export:
        logging.info('will clear export_dir=%s', FLAGS.export_dir)
        if gfile.IsDirectory(FLAGS.export_dir):
            gfile.DeleteRecursively(FLAGS.export_dir)

    export_out_dir = export(FLAGS.export_dir, pipeline_config_path,
                            FLAGS.checkpoint_path, FLAGS.asset_files,
                            FLAGS.verbose, **extra_params)

    if FLAGS.export_done_file:
        # Marker file written into the directory returned by export().
        flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
        logging.info('create export done file: %s', flag_file)
        with gfile.GFile(flag_file, 'w') as fout:
            fout.write('ExportDone')


if __name__ == '__main__':
    tf.app.run()