in easy_rec/python/export.py [0:0]
def _build_extra_params():
  """Collect the optional redis / oss export options from FLAGS.

  String options are forwarded only when non-empty and integer options only
  when positive; options that fail those checks are left out of the returned
  dict. The two `*_write_kv` switches are always forwarded as booleans.

  Returns:
    dict of keyword arguments to pass through to `export`.
  """
  extra_params = {}
  # String-valued options: forwarded only when non-empty.
  for name in ('redis_url', 'redis_passwd', 'redis_embedding_version',
               'oss_path', 'oss_endpoint', 'oss_ak', 'oss_sk',
               'oss_embedding_version'):
    value = getattr(FLAGS, name)
    if value:
      extra_params[name] = value
  # Integer-valued options: forwarded only when strictly positive.
  for name in ('redis_threads', 'redis_batch_size', 'redis_expire',
               'oss_timeout', 'oss_expire', 'oss_threads'):
    value = getattr(FLAGS, name)
    if value > 0:
      extra_params[name] = value
  # redis_write_kv: 0 -> False, any other value -> True.
  extra_params['redis_write_kv'] = FLAGS.redis_write_kv != 0
  # oss_write_kv: 1 -> True, anything else -> False. Always forwarded (like
  # redis_write_kv) so that an explicit 0 reaches `export` as False; a
  # truthiness guard here would silently drop the 0.
  extra_params['oss_write_kv'] = FLAGS.oss_write_kv == 1
  return extra_params


def _resolve_pipeline_config_path():
  """Choose the pipeline config file to export with.

  A `pipeline.config` found inside --model_dir takes precedence; otherwise
  --pipeline_config_path is used.

  Returns:
    path of the pipeline config file.

  Raises:
    ValueError: if neither --model_dir nor --pipeline_config_path is set, or
      --model_dir holds no pipeline.config and --pipeline_config_path is unset.
  """
  # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
  if not (FLAGS.model_dir or FLAGS.pipeline_config_path):
    raise ValueError(
        'At least one of model_dir and pipeline_config_path must be set.')
  if FLAGS.model_dir:
    model_dir_config = os.path.join(FLAGS.model_dir, 'pipeline.config')
    if file_io.file_exists(model_dir_config):
      logging.info('update pipeline_config_path to %s', model_dir_config)
      return model_dir_config
  if not FLAGS.pipeline_config_path:
    raise ValueError('pipeline.config not found in model_dir=%s and '
                     'pipeline_config_path is not set.' % FLAGS.model_dir)
  return FLAGS.pipeline_config_path


def main(argv):
  """Export a model according to the command-line FLAGS.

  Resolves the pipeline config, initializes horovod / SOK when the configured
  train_distribute strategy requires it, optionally clears --export_dir,
  calls `export`, and optionally writes a marker file into the returned
  export directory.

  Args:
    argv: positional command-line arguments from the app runner; unused, all
      options are read from FLAGS.
  """
  extra_params = _build_extra_params()
  pipeline_config_path = _resolve_pipeline_config_path()
  pipeline_config = config_util.get_configs_from_pipeline_file(
      pipeline_config_path)

  train_distribute = pipeline_config.train_config.train_distribute
  if train_distribute == DistributionStrategy.HorovodStrategy:
    estimator_utils.init_hvd()
  elif train_distribute in (DistributionStrategy.EmbeddingParallelStrategy,
                            DistributionStrategy.SokStrategy):
    # These strategies initialize horovod first, then SOK.
    estimator_utils.init_hvd()
    estimator_utils.init_sok()

  if FLAGS.clear_export:
    logging.info('will clear export_dir=%s', FLAGS.export_dir)
    if gfile.IsDirectory(FLAGS.export_dir):
      gfile.DeleteRecursively(FLAGS.export_dir)

  export_out_dir = export(FLAGS.export_dir, pipeline_config_path,
                          FLAGS.checkpoint_path, FLAGS.asset_files,
                          FLAGS.verbose, **extra_params)

  if FLAGS.export_done_file:
    # Marker file placed inside the directory returned by `export`.
    flag_file = os.path.join(export_out_dir, FLAGS.export_done_file)
    logging.info('create export done file: %s', flag_file)
    with gfile.GFile(flag_file, 'w') as fout:
      fout.write('ExportDone')