easy_rec/python/utils/export_big_model.py

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import logging
import os
import time

import numpy as np
import tensorflow as tf
from google.protobuf import json_format
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops.variables import global_variables
from tensorflow.python.platform.gfile import DeleteRecursively
from tensorflow.python.platform.gfile import Exists
from tensorflow.python.platform.gfile import GFile
from tensorflow.python.platform.gfile import Remove
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training.device_setter import replica_device_setter
from tensorflow.python.training.monitored_session import ChiefSessionCreator
from tensorflow.python.training.monitored_session import Scaffold
from tensorflow.python.training.saver import export_meta_graph

import easy_rec
from easy_rec.python.utils import constant
from easy_rec.python.utils import estimator_utils
from easy_rec.python.utils import io_util
from easy_rec.python.utils import proto_util
from easy_rec.python.utils.meta_graph_editor import EMBEDDING_INITIALIZERS
from easy_rec.python.utils.meta_graph_editor import MetaGraphEditor

if tf.__version__ >= '2.0':
  from tensorflow.python.framework.ops import disable_eager_execution
  disable_eager_execution()

ConfigProto = config_pb2.ConfigProto
GPUOptions = config_pb2.GPUOptions

INCR_UPDATE_SIGNATURE_KEY = 'incr_update_sig'


def export_big_model(export_dir, pipeline_config, redis_params,
                     serving_input_fn, estimator, checkpoint_path, verbose):
  for key in redis_params:
    logging.info('%s: %s' % (key, redis_params[key]))

  redis_cache_names = []
  for feature_config in pipeline_config.feature_configs:
    if feature_config.is_cache:
      if feature_config.feature_name:
        redis_cache_names.append(feature_config.feature_name)
      else:
        redis_cache_names.append(feature_config.input_names[0])
  logging.info('The list of cache names: %s' % ','.join(redis_cache_names))

  write_kv_lib_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
  kv_module = tf.load_op_library(write_kv_lib_path)
  try:
    sparse_kv_lib_path = os.path.join(easy_rec.ops_dir,
                                      'libwrite_sparse_kv.so')
    sparse_kv_module = tf.load_op_library(sparse_kv_lib_path)
  except Exception as ex:
    logging.warning('load libwrite_sparse_kv.so failed: %s' % str(ex))
    sparse_kv_module = None

  if not checkpoint_path:
    checkpoint_path = estimator_utils.latest_checkpoint(
        pipeline_config.model_dir)
  logging.info('checkpoint_path = %s' % checkpoint_path)

  server = None
  cluster = None
  if 'TF_CONFIG' in os.environ:
    # change chief to master
    tf_config = estimator_utils.chief_to_master()
    if tf_config['task']['type'] == 'ps':
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = tf.train.Server(
          cluster, job_name='ps', task_index=tf_config['task']['index'])
      server.join()
    elif tf_config['task']['type'] == 'master':
      if 'ps' in tf_config['cluster']:
        cluster = tf.train.ClusterSpec(tf_config['cluster'])
        server = tf.train.Server(cluster, job_name='master', task_index=0)
        server_target = server.target
        logging.info('server_target = %s' % server_target)

  serving_input = serving_input_fn()
  features = serving_input.features
  inputs = serving_input.receiver_tensors

  if cluster:
    logging.info('cluster = ' + str(cluster))
  with tf.device(
      replica_device_setter(
          worker_device='/job:master/task:0', cluster=cluster)):
    outputs = estimator._export_model_fn(features, None, None,
                                         estimator.params).predictions

  meta_graph_def = export_meta_graph()

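  # The meta graph version doubles as the embedding version in redis: either
  # the user-specified redis_embedding_version or, if unset, the current unix
  # timestamp computed below.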
  redis_embedding_version = redis_params.get('redis_embedding_version', '')
  if not redis_embedding_version:
    meta_graph_def.meta_info_def.meta_graph_version =\
        str(int(time.time()))
  else:
    meta_graph_def.meta_info_def.meta_graph_version = redis_embedding_version
  logging.info('meta_graph_version = %s' %
               meta_graph_def.meta_info_def.meta_graph_version)

  embed_var_parts = {}
  embed_norm_name = {}
  embed_spos = {}
  # pai embedding variable
  embedding_vars = {}
  norm_name_to_ids = {}
  for x in global_variables():
    if 'EmbeddingVariable' in str(type(x)):
      norm_name, part_id = proto_util.get_norm_embed_name(x.name)
      norm_name_to_ids[norm_name] = 1
      tmp_export = x.export()
      if x.device not in embedding_vars:
        embedding_vars[x.device] = [(norm_name, tmp_export.keys,
                                     tmp_export.values)]
      else:
        embedding_vars[x.device].append(
            (norm_name, tmp_export.keys, tmp_export.values))
    elif '/embedding_weights:' in x.name or '/embedding_weights/part_' in x.name:
      norm_name, part_id = proto_util.get_norm_embed_name(x.name)
      norm_name_to_ids[norm_name] = 1
      embed_norm_name[x] = norm_name
      if norm_name not in embed_var_parts:
        embed_var_parts[norm_name] = {part_id: x}
      else:
        embed_var_parts[norm_name][part_id] = x

  for tid, t in enumerate(norm_name_to_ids.keys()):
    norm_name_to_ids[t] = str(tid)

  is_cache_from_redis = [  # noqa: F841
      proto_util.is_cache_from_redis(x, redis_cache_names)
      for x in norm_name_to_ids
  ]

  for x in embed_norm_name:
    embed_norm_name[x] = norm_name_to_ids[embed_norm_name[x]]

  total_num = 0
  for norm_name in embed_var_parts:
    parts = embed_var_parts[norm_name]
    spos = 0
    part_ids = list(parts.keys())
    part_ids.sort()
    total_num += len(part_ids)
    for part_id in part_ids:
      embed_spos[parts[part_id]] = spos
      spos += parts[part_id].get_shape()[0]

  redis_url = redis_params.get('redis_url', '')
  redis_passwd = redis_params.get('redis_passwd', '')
  logging.info('will export to redis: %s %s' % (redis_url, redis_passwd))

  if redis_params.get('redis_write_kv', ''):
    # group embed by devices
    per_device_vars = {}
    for x in embed_norm_name:
      if x.device not in per_device_vars:
        per_device_vars[x.device] = [x]
      else:
        per_device_vars[x.device].append(x)

    all_write_res = []
    for tmp_dev in per_device_vars:
      tmp_vars = per_device_vars[tmp_dev]
      with tf.device(tmp_dev):
        tmp_names = [embed_norm_name[v] for v in tmp_vars]
        tmp_spos = [np.array(embed_spos[v], dtype=np.int64) for v in tmp_vars]
        write_kv_res = kv_module.write_kv(
            tmp_names,
            tmp_vars,
            tmp_spos,
            url=redis_url,
            password=redis_passwd,
            timeout=redis_params.get('redis_timeout', 1500),
            version=meta_graph_def.meta_info_def.meta_graph_version,
            threads=redis_params.get('redis_threads', 5),
            batch_size=redis_params.get('redis_batch_size', 32),
            expire=redis_params.get('redis_expire', 24),
            verbose=verbose)
        all_write_res.append(write_kv_res)

    for tmp_dev in embedding_vars:
      with tf.device(tmp_dev):
        tmp_vs = embedding_vars[tmp_dev]
        tmp_sparse_names = [norm_name_to_ids[x[0]] for x in tmp_vs]
        tmp_sparse_keys = [x[1] for x in tmp_vs]
        tmp_sparse_vals = [x[2] for x in tmp_vs]
        write_sparse_kv_res = sparse_kv_module.write_sparse_kv(
            tmp_sparse_names,
            tmp_sparse_vals,
            tmp_sparse_keys,
            url=redis_url,
            password=redis_passwd,
            timeout=redis_params.get('redis_timeout', 1500),
            version=meta_graph_def.meta_info_def.meta_graph_version,
            threads=redis_params.get('redis_threads', 5),
            batch_size=redis_params.get('redis_batch_size', 32),
            expire=redis_params.get('redis_expire', 24),
            verbose=verbose)
        all_write_res.append(write_sparse_kv_res)

    session_config = ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
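    # Restore the checkpoint in a MonitoredSession and run all queued
    # write_kv / write_sparse_kv ops so the embedding tables are pushed to
    # redis before the serving graph itself is exported.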
    chief_sess_creator = ChiefSessionCreator(
        master=server.target if server else '',
        checkpoint_filename_with_path=checkpoint_path,
        config=session_config)
    with tf.train.MonitoredSession(
        session_creator=chief_sess_creator,
        hooks=None,
        stop_grace_period_secs=120) as sess:
      dump_flags = sess.run(all_write_res)
      logging.info('write embedding to redis succeed: %s' % str(dump_flags))
  else:
    logging.info('will skip write embedding to redis because '
                 'redis_write_kv is set to 0.')

  # delete embedding_weights collections so that it could be re imported
  tmp_drop = []
  for k in meta_graph_def.collection_def:
    v = meta_graph_def.collection_def[k]
    if len(
        v.node_list.value) > 0 and 'embedding_weights' in v.node_list.value[0]:
      tmp_drop.append(k)
  for k in tmp_drop:
    meta_graph_def.collection_def.pop(k)

  meta_graph_editor = MetaGraphEditor(
      os.path.join(easy_rec.ops_dir, 'libembed_op.so'),
      None,
      redis_url,
      redis_passwd,
      redis_timeout=redis_params.get('redis_timeout', 600),
      redis_cache_names=redis_cache_names,
      meta_graph_def=meta_graph_def,
      norm_name_to_ids=norm_name_to_ids,
      debug_dir=export_dir if verbose else '')
  meta_graph_editor.edit_graph()

  tf.reset_default_graph()
  saver = tf.train.import_meta_graph(meta_graph_editor._meta_graph_def)
  graph = tf.get_default_graph()

  embed_name_to_id_file = os.path.join(export_dir, 'embed_name_to_ids.txt')
  with GFile(embed_name_to_id_file, 'w') as fout:
    for tmp_norm_name in norm_name_to_ids:
      fout.write('%s\t%s\n' % (tmp_norm_name, norm_name_to_ids[tmp_norm_name]))
  ops.add_to_collection(
      tf.GraphKeys.ASSET_FILEPATHS,
      tf.constant(
          embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))

  export_dir = os.path.join(export_dir,
                            meta_graph_def.meta_info_def.meta_graph_version)
  export_dir = io_util.fix_oss_dir(export_dir)
  logging.info('export_dir=%s' % export_dir)
  if Exists(export_dir):
    logging.info('will delete old dir: %s' % export_dir)
    DeleteRecursively(export_dir)

  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
  tensor_info_inputs = {}
  for tmp_key in inputs:
    tmp = graph.get_tensor_by_name(inputs[tmp_key].name)
    tensor_info_inputs[tmp_key] = \
        tf.saved_model.utils.build_tensor_info(tmp)
  tensor_info_outputs = {}
  for tmp_key in outputs:
    tmp = graph.get_tensor_by_name(outputs[tmp_key].name)
    tensor_info_outputs[tmp_key] = \
        tf.saved_model.utils.build_tensor_info(tmp)
  signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs=tensor_info_inputs,
          outputs=tensor_info_outputs,
          method_name=signature_constants.PREDICT_METHOD_NAME))

  session_config = ConfigProto(
      allow_soft_placement=True, log_device_placement=True)
  saver = tf.train.Saver()
  with tf.Session(target=server.target if server else '') as sess:
    saver.restore(sess, checkpoint_path)
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature,
        },
        assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
        saver=saver,
        strip_default_attrs=True,
        clear_devices=True)
    builder.save()

  # remove temporary files
  Remove(embed_name_to_id_file)
  return export_dir


def export_big_model_to_oss(export_dir, pipeline_config, oss_params,
                            serving_input_fn, estimator, checkpoint_path,
                            verbose):
  for key in oss_params:
    logging.info('%s: %s' % (key, oss_params[key]))

  write_kv_lib_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
  kv_module = tf.load_op_library(write_kv_lib_path)

  if not checkpoint_path:
    checkpoint_path = estimator_utils.latest_checkpoint(
        pipeline_config.model_dir)
  logging.info('checkpoint_path = %s' % checkpoint_path)

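  # The cluster setup below mirrors export_big_model: under TF_CONFIG, ps
  # tasks block in server.join() while the master task builds the serving
  # graph and performs the export.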
  server = None
  cluster = None
  if 'TF_CONFIG' in os.environ:
    # change chief to master
    tf_config = estimator_utils.chief_to_master()
    if tf_config['task']['type'] == 'ps':
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = tf.train.Server(
          cluster, job_name='ps', task_index=tf_config['task']['index'])
      server.join()
    elif tf_config['task']['type'] == 'master':
      if 'ps' in tf_config['cluster']:
        cluster = tf.train.ClusterSpec(tf_config['cluster'])
        server = tf.train.Server(cluster, job_name='master', task_index=0)
        server_target = server.target
        logging.info('server_target = %s' % server_target)

  serving_input = serving_input_fn()
  features = serving_input.features
  inputs = serving_input.receiver_tensors

  if cluster:
    logging.info('cluster = ' + str(cluster))
  with tf.device(
      replica_device_setter(
          worker_device='/job:master/task:0', cluster=cluster)):
    outputs = estimator._export_model_fn(features, None, None,
                                         estimator.params).predictions

  meta_graph_def = export_meta_graph()
  meta_graph_def.meta_info_def.meta_graph_version = str(int(time.time()))

  oss_embedding_version = oss_params.get('oss_embedding_version', '')
  if not oss_embedding_version:
    meta_graph_def.meta_info_def.meta_graph_version =\
        str(int(time.time()))
  else:
    meta_graph_def.meta_info_def.meta_graph_version = oss_embedding_version
  logging.info('meta_graph_version = %s' %
               meta_graph_def.meta_info_def.meta_graph_version)

  embed_var_parts = {}
  embed_norm_name = {}
  embed_spos = {}
  # pai embedding variable
  embedding_vars = {}
  norm_name_to_ids = {}
  for x in global_variables():
    tf.logging.info('global var: %s %s %s' % (x.name, str(type(x)), x.device))
    if 'EmbeddingVariable' in str(type(x)):
      norm_name, part_id = proto_util.get_norm_embed_name(x.name)
      norm_name_to_ids[norm_name] = 1
      tmp_export = x.export()
      if x.device not in embedding_vars:
        embedding_vars[x.device] = [(norm_name, tmp_export.keys,
                                     tmp_export.values, part_id)]
      else:
        embedding_vars[x.device].append(
            (norm_name, tmp_export.keys, tmp_export.values, part_id))
    elif '/embedding_weights:' in x.name or '/embedding_weights/part_' in x.name:
      norm_name, part_id = proto_util.get_norm_embed_name(x.name)
      norm_name_to_ids[norm_name] = 1
      embed_norm_name[x] = norm_name
      if norm_name not in embed_var_parts:
        embed_var_parts[norm_name] = {part_id: x}
      else:
        embed_var_parts[norm_name][part_id] = x

  for tid, t in enumerate(norm_name_to_ids.keys()):
    norm_name_to_ids[t] = str(tid)

  for x in embed_norm_name:
    embed_norm_name[x] = norm_name_to_ids[embed_norm_name[x]]

  total_num = 0
  for norm_name in embed_var_parts:
    parts = embed_var_parts[norm_name]
    spos = 0
    part_ids = list(parts.keys())
    part_ids.sort()
    total_num += len(part_ids)
    for part_id in part_ids:
      embed_spos[parts[part_id]] = spos
      spos += parts[part_id].get_shape()[0]

  oss_path = oss_params.get('oss_path', '')
  oss_endpoint = oss_params.get('oss_endpoint', '')
  oss_ak = oss_params.get('oss_ak', '')
  oss_sk = oss_params.get('oss_sk', '')
  logging.info('will export to oss: %s %s %s %s', oss_path, oss_endpoint,
               oss_ak, oss_sk)

  if oss_params.get('oss_write_kv', ''):
    # group embed by devices
    per_device_vars = {}
    for x in embed_norm_name:
      if x.device not in per_device_vars:
        per_device_vars[x.device] = [x]
      else:
        per_device_vars[x.device].append(x)

    all_write_res = []
    for tmp_dev in per_device_vars:
      tmp_vars = per_device_vars[tmp_dev]
      with tf.device(tmp_dev):
        tmp_names = [embed_norm_name[v] for v in tmp_vars]
        tmp_spos = [np.array(embed_spos[v], dtype=np.int64) for v in tmp_vars]
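        # oss_write_kv is a custom op from libembed_op.so; it is expected to
        # dump each dense embedding slice (name, values, start offset) to the
        # configured OSS path.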
        write_kv_res = kv_module.oss_write_kv(
            tmp_names,
            tmp_vars,
            tmp_spos,
            osspath=oss_path,
            endpoint=oss_endpoint,
            ak=oss_ak,
            sk=oss_sk,
            threads=oss_params.get('oss_threads', 5),
            timeout=5,
            expire=5,
            verbose=verbose)
        all_write_res.append(write_kv_res)

    for tmp_dev in embedding_vars:
      with tf.device(tmp_dev):
        tmp_vs = embedding_vars[tmp_dev]
        tmp_sparse_names = [norm_name_to_ids[x[0]] for x in tmp_vs]
        tmp_sparse_keys = [x[1] for x in tmp_vs]
        tmp_sparse_vals = [x[2] for x in tmp_vs]
        tmp_part_ids = [x[3] for x in tmp_vs]
        write_sparse_kv_res = kv_module.oss_write_sparse_kv(
            tmp_sparse_names,
            tmp_sparse_vals,
            tmp_sparse_keys,
            tmp_part_ids,
            osspath=oss_path,
            endpoint=oss_endpoint,
            ak=oss_ak,
            sk=oss_sk,
            version=meta_graph_def.meta_info_def.meta_graph_version,
            threads=oss_params.get('oss_threads', 5),
            verbose=verbose)
        all_write_res.append(write_sparse_kv_res)

    session_config = ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    chief_sess_creator = ChiefSessionCreator(
        master=server.target if server else '',
        checkpoint_filename_with_path=checkpoint_path,
        config=session_config)
    with tf.train.MonitoredSession(
        session_creator=chief_sess_creator,
        hooks=None,
        stop_grace_period_secs=120) as sess:
      dump_flags = sess.run(all_write_res)
      logging.info('write embedding to oss succeed: %s' % str(dump_flags))
  else:
    logging.info('will skip write embedding to oss because '
                 'oss_write_kv is set to 0.')

  # delete embedding_weights collections so that it could be re imported
  tmp_drop = []
  for k in meta_graph_def.collection_def:
    v = meta_graph_def.collection_def[k]
    if len(
        v.node_list.value) > 0 and 'embedding_weights' in v.node_list.value[0]:
      tmp_drop.append(k)
  for k in tmp_drop:
    meta_graph_def.collection_def.pop(k)

  meta_graph_editor = MetaGraphEditor(
      os.path.join(easy_rec.ops_dir, 'libembed_op.so'),
      None,
      oss_path=oss_path,
      oss_endpoint=oss_endpoint,
      oss_ak=oss_ak,
      oss_sk=oss_sk,
      oss_timeout=oss_params.get('oss_timeout', 1500),
      meta_graph_def=meta_graph_def,
      norm_name_to_ids=norm_name_to_ids,
      incr_update_params=oss_params.get('incr_update', None),
      debug_dir=export_dir if verbose else '')
  meta_graph_editor.edit_graph_for_oss()

  tf.reset_default_graph()
  saver = tf.train.import_meta_graph(meta_graph_editor._meta_graph_def)
  graph = tf.get_default_graph()

  embed_name_to_id_file = os.path.join(export_dir, 'embed_name_to_ids.txt')
  with GFile(embed_name_to_id_file, 'w') as fout:
    for tmp_norm_name in norm_name_to_ids:
      fout.write('%s\t%s\n' % (tmp_norm_name, norm_name_to_ids[tmp_norm_name]))
  ops.add_to_collection(
      ops.GraphKeys.ASSET_FILEPATHS,
      tf.constant(
          embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))

  if 'incr_update' in oss_params:
    dense_train_vars_path = os.path.join(
        os.path.dirname(checkpoint_path), constant.DENSE_UPDATE_VARIABLES)
    ops.add_to_collection(
        ops.GraphKeys.ASSET_FILEPATHS,
        tf.constant(
            dense_train_vars_path,
            dtype=tf.string,
            name=constant.DENSE_UPDATE_VARIABLES))

    asset_file = 'incr_update.txt'
    asset_file_path = os.path.join(export_dir, asset_file)
    with GFile(asset_file_path, 'w') as fout:
      incr_update = oss_params['incr_update']
      incr_update_json = {}
      if 'kafka' in incr_update:
        incr_update_json['storage'] = 'kafka'
        incr_update_json['kafka'] = json.loads(
            json_format.MessageToJson(
                incr_update['kafka'], preserving_proto_field_name=True))
      elif 'datahub' in incr_update:
        incr_update_json['storage'] = 'datahub'
        incr_update_json['datahub'] = json.loads(
            json_format.MessageToJson(
                incr_update['datahub'], preserving_proto_field_name=True))
      elif 'fs' in incr_update:
        incr_update_json['storage'] = 'fs'
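        # For filesystem-based incremental updates, the mount path of the fs
        # config is recorded as the directory where incremental saves land.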
        incr_update_json['fs'] = {
            'incr_save_dir': incr_update['fs'].mount_path
        }
      json.dump(incr_update_json, fout, indent=2)
    ops.add_to_collection(
        ops.GraphKeys.ASSET_FILEPATHS,
        tf.constant(asset_file_path, dtype=tf.string, name=asset_file))

  export_dir = os.path.join(export_dir,
                            meta_graph_def.meta_info_def.meta_graph_version)
  export_dir = io_util.fix_oss_dir(export_dir)
  logging.info('export_dir=%s' % export_dir)
  if Exists(export_dir):
    logging.info('will delete old dir: %s' % export_dir)
    DeleteRecursively(export_dir)

  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
  tensor_info_inputs = {}
  for tmp_key in inputs:
    tmp = graph.get_tensor_by_name(inputs[tmp_key].name)
    tensor_info_inputs[tmp_key] = \
        tf.saved_model.utils.build_tensor_info(tmp)
  tensor_info_outputs = {}
  for tmp_key in outputs:
    tmp = graph.get_tensor_by_name(outputs[tmp_key].name)
    tensor_info_outputs[tmp_key] = \
        tf.saved_model.utils.build_tensor_info(tmp)
  signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs=tensor_info_inputs,
          outputs=tensor_info_outputs,
          method_name=signature_constants.PREDICT_METHOD_NAME))

  if 'incr_update' in oss_params:
    incr_update_inputs = meta_graph_editor.sparse_update_inputs
    incr_update_outputs = meta_graph_editor.sparse_update_outputs
    incr_update_inputs.update(meta_graph_editor.dense_update_inputs)
    incr_update_outputs.update(meta_graph_editor.dense_update_outputs)
    tensor_info_incr_update_inputs = {}
    tensor_info_incr_update_outputs = {}
    for tmp_key in incr_update_inputs:
      tmp = graph.get_tensor_by_name(incr_update_inputs[tmp_key].name)
      tensor_info_incr_update_inputs[tmp_key] = \
          tf.saved_model.utils.build_tensor_info(tmp)
    for tmp_key in incr_update_outputs:
      tmp = graph.get_tensor_by_name(incr_update_outputs[tmp_key].name)
      tensor_info_incr_update_outputs[tmp_key] = \
          tf.saved_model.utils.build_tensor_info(tmp)
    incr_update_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs=tensor_info_incr_update_inputs,
            outputs=tensor_info_incr_update_outputs,
            method_name=signature_constants.PREDICT_METHOD_NAME))
  else:
    incr_update_signature = None

  session_config = ConfigProto(
      allow_soft_placement=True, log_device_placement=True)
  saver = tf.train.Saver()
  with tf.Session(target=server.target if server else '') as sess:
    saver.restore(sess, checkpoint_path)
    main_op = tf.group([
        Scaffold.default_local_init_op(),
        ops.get_collection(EMBEDDING_INITIALIZERS)
    ])
    incr_update_sig_map = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
    }
    if incr_update_signature is not None:
      incr_update_sig_map[INCR_UPDATE_SIGNATURE_KEY] = incr_update_signature
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map=incr_update_sig_map,
        assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
        saver=saver,
        main_op=main_op,
        strip_default_attrs=True,
        clear_devices=True)
    builder.save()

  # remove temporary files
  Remove(embed_name_to_id_file)
  return export_dir
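
# ----------------------------------------------------------------------------
# Usage sketch (not executed; the parameter values below are illustrative
# assumptions, not defaults enforced by this module). The keys shown are the
# ones read from redis_params above:
#
#   redis_params = {
#       'redis_url': '127.0.0.1:6379',   # hypothetical endpoint
#       'redis_passwd': '',
#       'redis_write_kv': '1',           # any truthy value triggers the dump
#       'redis_threads': 5,
#       'redis_batch_size': 32,
#   }
#   export_big_model(export_dir, pipeline_config, redis_params,
#                    serving_input_fn, estimator, checkpoint_path='',
#                    verbose=False)
# ----------------------------------------------------------------------------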