in easy_rec/python/utils/export_big_model.py [0:0]
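# NOTE: the listing below omits the file's module-level imports. Judging by
# the names used, the original roughly imports the following (a best-effort
# reconstruction, not verbatim from the source):
#   import logging, os, time
#   import numpy as np
#   import tensorflow as tf
#   import easy_rec
#   from easy_rec.python.utils import estimator_utils, io_util, proto_util
#   from easy_rec.python.utils.meta_graph_editor import MetaGraphEditor
# plus TF symbols such as replica_device_setter, export_meta_graph,
# global_variables, ConfigProto, ChiefSessionCreator, GFile, Exists,
# DeleteRecursively, Remove, ops and signature_constants.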
def export_big_model(export_dir, pipeline_config, redis_params,
                     serving_input_fn, estimator, checkpoint_path, verbose):
  for key in redis_params:
    logging.info('%s: %s' % (key, redis_params[key]))

  # collect the names of features flagged with is_cache; they are passed to
  # the meta graph editor below as redis_cache_names
  redis_cache_names = []
  for feature_config in pipeline_config.feature_configs:
    if feature_config.is_cache:
      if feature_config.feature_name:
        redis_cache_names.append(feature_config.feature_name)
      else:
        redis_cache_names.append(feature_config.input_names[0])
  logging.info('The list of cache names: %s' % ','.join(redis_cache_names))
  # load the custom ops that write embeddings to redis; the dense writer is
  # required, the sparse writer (for PAI EmbeddingVariable) is optional
  write_kv_lib_path = os.path.join(easy_rec.ops_dir, 'libembed_op.so')
  kv_module = tf.load_op_library(write_kv_lib_path)
  try:
    sparse_kv_lib_path = os.path.join(easy_rec.ops_dir,
                                      'libwrite_sparse_kv.so')
    sparse_kv_module = tf.load_op_library(sparse_kv_lib_path)
  except Exception as ex:
    logging.warning('load libwrite_sparse_kv.so failed: %s' % str(ex))
    sparse_kv_module = None
  if not checkpoint_path:
    checkpoint_path = estimator_utils.latest_checkpoint(
        pipeline_config.model_dir)
  logging.info('checkpoint_path = %s' % checkpoint_path)
  server = None
  cluster = None
  if 'TF_CONFIG' in os.environ:
    # rename the chief task to master so a standard ps/master server layout
    # can be built from TF_CONFIG
    tf_config = estimator_utils.chief_to_master()
    if tf_config['task']['type'] == 'ps':
      cluster = tf.train.ClusterSpec(tf_config['cluster'])
      server = tf.train.Server(
          cluster, job_name='ps', task_index=tf_config['task']['index'])
      server.join()
    elif tf_config['task']['type'] == 'master':
      if 'ps' in tf_config['cluster']:
        cluster = tf.train.ClusterSpec(tf_config['cluster'])
        server = tf.train.Server(cluster, job_name='master', task_index=0)
        server_target = server.target
        logging.info('server_target = %s' % server_target)
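  # only the master gets past this point: ps tasks never return from
  # server.join() above, so the export below runs on a single task that can
  # still place ops on the parameter servers through the cluster spec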
  serving_input = serving_input_fn()
  features = serving_input.features
  inputs = serving_input.receiver_tensors
  if cluster:
    logging.info('cluster = ' + str(cluster))
  # build the inference graph; with a cluster, replica_device_setter keeps
  # variables on the parameter servers, without one it is a no-op
  with tf.device(
      replica_device_setter(
          worker_device='/job:master/task:0', cluster=cluster)):
    outputs = estimator._export_model_fn(features, None, None,
                                         estimator.params).predictions
  meta_graph_def = export_meta_graph()

  # the meta graph version doubles as the embedding version in redis; it
  # defaults to the current unix timestamp when not set explicitly
  redis_embedding_version = redis_params.get('redis_embedding_version', '')
  if not redis_embedding_version:
    meta_graph_def.meta_info_def.meta_graph_version = \
        str(int(time.time()))
  else:
    meta_graph_def.meta_info_def.meta_graph_version = redis_embedding_version
  logging.info('meta_graph_version = %s' %
               meta_graph_def.meta_info_def.meta_graph_version)
  embed_var_parts = {}
  embed_norm_name = {}
  embed_spos = {}
  # pai embedding variables, grouped by device
  embedding_vars = {}
  norm_name_to_ids = {}
  for x in global_variables():
    if 'EmbeddingVariable' in str(type(x)):
      # PAI EmbeddingVariable: export() yields the sparse (keys, values) pairs
      norm_name, part_id = proto_util.get_norm_embed_name(x.name)
      norm_name_to_ids[norm_name] = 1
      tmp_export = x.export()
      if x.device not in embedding_vars:
        embedding_vars[x.device] = [(norm_name, tmp_export.keys,
                                     tmp_export.values)]
      else:
        embedding_vars[x.device].append(
            (norm_name, tmp_export.keys, tmp_export.values))
    elif '/embedding_weights:' in x.name or \
        '/embedding_weights/part_' in x.name:
      # ordinary (possibly partitioned) embedding weight variables
      norm_name, part_id = proto_util.get_norm_embed_name(x.name)
      norm_name_to_ids[norm_name] = 1
      embed_norm_name[x] = norm_name
      if norm_name not in embed_var_parts:
        embed_var_parts[norm_name] = {part_id: x}
      else:
        embed_var_parts[norm_name][part_id] = x
  # replace the placeholder values with compact string ids
  for tid, t in enumerate(norm_name_to_ids.keys()):
    norm_name_to_ids[t] = str(tid)
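  # after the renumbering above, norm_name_to_ids maps each normalized
  # embedding name to a compact string id, e.g.
  # {'user_tower/uid_embedding': '0', 'item_tower/iid_embedding': '1'}
  # (example names are hypothetical); these ids, not the full names, are
  # what the write ops below send to redis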
  is_cache_from_redis = [  # noqa: F841
      proto_util.is_cache_from_redis(x, redis_cache_names)
      for x in norm_name_to_ids
  ]
  for x in embed_norm_name:
    embed_norm_name[x] = norm_name_to_ids[embed_norm_name[x]]
  # compute the starting row of every partition inside its embedding table
  total_num = 0
  for norm_name in embed_var_parts:
    parts = embed_var_parts[norm_name]
    spos = 0
    part_ids = list(parts.keys())
    part_ids.sort()
    total_num += len(part_ids)
    for part_id in part_ids:
      embed_spos[parts[part_id]] = spos
      spos += parts[part_id].get_shape()[0]
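  # embed_spos records where each partition begins within its full embedding
  # table: e.g. parts with shapes [100, d], [100, d] and [50, d] get start
  # offsets 0, 100 and 200 respectively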
  redis_url = redis_params.get('redis_url', '')
  redis_passwd = redis_params.get('redis_passwd', '')
  logging.info('will export to redis: %s %s' % (redis_url, redis_passwd))

  if redis_params.get('redis_write_kv', ''):
    # group embedding variables by device so each write op runs on the
    # device that owns its variables
    per_device_vars = {}
    for x in embed_norm_name:
      if x.device not in per_device_vars:
        per_device_vars[x.device] = [x]
      else:
        per_device_vars[x.device].append(x)
    all_write_res = []
    for tmp_dev in per_device_vars:
      tmp_vars = per_device_vars[tmp_dev]
      with tf.device(tmp_dev):
        tmp_names = [embed_norm_name[v] for v in tmp_vars]
        tmp_spos = [np.array(embed_spos[v], dtype=np.int64) for v in tmp_vars]
        write_kv_res = kv_module.write_kv(
            tmp_names,
            tmp_vars,
            tmp_spos,
            url=redis_url,
            password=redis_passwd,
            timeout=redis_params.get('redis_timeout', 1500),
            version=meta_graph_def.meta_info_def.meta_graph_version,
            threads=redis_params.get('redis_threads', 5),
            batch_size=redis_params.get('redis_batch_size', 32),
            expire=redis_params.get('redis_expire', 24),
            verbose=verbose)
        all_write_res.append(write_kv_res)
    # the sparse writer needs libwrite_sparse_kv.so; if it failed to load
    # above and the model contains EmbeddingVariables, this loop will fail
    for tmp_dev in embedding_vars:
      with tf.device(tmp_dev):
        tmp_vs = embedding_vars[tmp_dev]
        tmp_sparse_names = [norm_name_to_ids[x[0]] for x in tmp_vs]
        tmp_sparse_keys = [x[1] for x in tmp_vs]
        tmp_sparse_vals = [x[2] for x in tmp_vs]
        write_sparse_kv_res = sparse_kv_module.write_sparse_kv(
            tmp_sparse_names,
            tmp_sparse_vals,
            tmp_sparse_keys,
            url=redis_url,
            password=redis_passwd,
            timeout=redis_params.get('redis_timeout', 1500),
            version=meta_graph_def.meta_info_def.meta_graph_version,
            threads=redis_params.get('redis_threads', 5),
            batch_size=redis_params.get('redis_batch_size', 32),
            expire=redis_params.get('redis_expire', 24),
            verbose=verbose)
        all_write_res.append(write_sparse_kv_res)
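    # all write ops run in a single session below; ChiefSessionCreator
    # restores checkpoint_path first, so the trained embedding values are
    # what gets written to redis, not the random initializers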
    session_config = ConfigProto(
        allow_soft_placement=True, log_device_placement=False)
    chief_sess_creator = ChiefSessionCreator(
        master=server.target if server else '',
        checkpoint_filename_with_path=checkpoint_path,
        config=session_config)
    with tf.train.MonitoredSession(
        session_creator=chief_sess_creator,
        hooks=None,
        stop_grace_period_secs=120) as sess:
      dump_flags = sess.run(all_write_res)
      logging.info('write embedding to redis succeeded: %s' % str(dump_flags))
  else:
    logging.info('will skip writing embeddings to redis because '
                 'redis_write_kv is set to 0.')
  # drop the embedding_weights collections so the edited graph can be
  # re-imported without the original embedding variables
  tmp_drop = []
  for k in meta_graph_def.collection_def:
    v = meta_graph_def.collection_def[k]
    if len(v.node_list.value) > 0 and \
        'embedding_weights' in v.node_list.value[0]:
      tmp_drop.append(k)
  for k in tmp_drop:
    meta_graph_def.collection_def.pop(k)
  # rewrite the meta graph to route embedding lookups through the
  # redis-backed custom ops in libembed_op.so
  meta_graph_editor = MetaGraphEditor(
      os.path.join(easy_rec.ops_dir, 'libembed_op.so'),
      None,
      redis_url,
      redis_passwd,
      redis_timeout=redis_params.get('redis_timeout', 600),
      redis_cache_names=redis_cache_names,
      meta_graph_def=meta_graph_def,
      norm_name_to_ids=norm_name_to_ids,
      debug_dir=export_dir if verbose else '')
  meta_graph_editor.edit_graph()

  # re-import the edited graph into a clean default graph
  tf.reset_default_graph()
  saver = tf.train.import_meta_graph(meta_graph_editor._meta_graph_def)
  graph = tf.get_default_graph()
  # dump the name -> id mapping and register it as a model asset
  embed_name_to_id_file = os.path.join(export_dir, 'embed_name_to_ids.txt')
  with GFile(embed_name_to_id_file, 'w') as fout:
    for tmp_norm_name in norm_name_to_ids:
      fout.write('%s\t%s\n' % (tmp_norm_name, norm_name_to_ids[tmp_norm_name]))
  ops.add_to_collection(
      tf.GraphKeys.ASSET_FILEPATHS,
      tf.constant(
          embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))
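  # files registered in ASSET_FILEPATHS are copied into the SavedModel's
  # assets directory when the builder saves, so the id mapping ships with
  # the exported model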
  # export into a version-named subdirectory, replacing any previous export
  export_dir = os.path.join(export_dir,
                            meta_graph_def.meta_info_def.meta_graph_version)
  export_dir = io_util.fix_oss_dir(export_dir)
  logging.info('export_dir=%s' % export_dir)
  if Exists(export_dir):
    logging.info('will delete old dir: %s' % export_dir)
    DeleteRecursively(export_dir)
  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
  # build the serving signature from the receiver tensors and predictions,
  # resolved by name in the re-imported graph
  tensor_info_inputs = {}
  for tmp_key in inputs:
    tmp = graph.get_tensor_by_name(inputs[tmp_key].name)
    tensor_info_inputs[tmp_key] = \
        tf.saved_model.utils.build_tensor_info(tmp)
  tensor_info_outputs = {}
  for tmp_key in outputs:
    tmp = graph.get_tensor_by_name(outputs[tmp_key].name)
    tensor_info_outputs[tmp_key] = \
        tf.saved_model.utils.build_tensor_info(tmp)
  signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs=tensor_info_inputs,
          outputs=tensor_info_outputs,
          method_name=signature_constants.PREDICT_METHOD_NAME))
  session_config = ConfigProto(
      allow_soft_placement=True, log_device_placement=True)
  # a fresh Saver over the edited graph, replacing the one returned by
  # import_meta_graph above
  saver = tf.train.Saver()
  with tf.Session(target=server.target if server else '') as sess:
    saver.restore(sess, checkpoint_path)
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
        },
        assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
        saver=saver,
        strip_default_attrs=True,
        clear_devices=True)
  builder.save()
  # remove temporary files
  Remove(embed_name_to_id_file)
  return export_dir
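# Example invocation (a minimal sketch; the export path and estimator are
# hypothetical, while the redis_params keys are exactly the ones read via
# redis_params.get(...) above):
#
#   redis_params = {
#       'redis_url': '127.0.0.1:6379',
#       'redis_passwd': '',
#       'redis_write_kv': '1',          # falsy value skips the embedding dump
#       'redis_timeout': 1500,
#       'redis_threads': 5,
#       'redis_batch_size': 32,
#       'redis_expire': 24,
#       'redis_embedding_version': '',  # '' -> current unix timestamp
#   }
#   final_dir = export_big_model('oss://bucket/export', pipeline_config,
#                                redis_params, serving_input_fn, estimator,
#                                None,  # None -> latest checkpoint in model_dir
#                                verbose=False)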