# easy_rec/python/tools/edit_lookup_graph.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import logging
import os
import sys
import tensorflow as tf
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.lib.io.file_io import file_exists
from tensorflow.python.lib.io.file_io import recursive_create_dir
from tensorflow.python.platform.gfile import GFile
import easy_rec
from easy_rec.python.utils.meta_graph_editor import MetaGraphEditor
logging.basicConfig(
format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
level=logging.INFO)
if __name__ == '__main__':
  """Replace the default embedding_lookup ops with self defined embedding lookup ops.

  The embedding data are now stored in redis; at serving time the custom op
  retrieves embedding vectors by the key {version}_{embed_name}_{embed_id}.

  Example:
    python -m easy_rec.python.tools.edit_lookup_graph
      --saved_model_dir rtp_large_embedding_export/1604304644
      --output_dir ./after_edit_save
      --test_data_path data/test/rtp/xys_cxr_fg_sample_test2_with_lbl.txt
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--saved_model_dir', type=str, default=None, help='saved model dir')
  parser.add_argument('--output_dir', type=str, default=None, help='output dir')
  parser.add_argument(
      '--redis_url', type=str, default='127.0.0.1:6379', help='redis url')
  parser.add_argument(
      '--redis_passwd', type=str, default='', help='redis password')
  parser.add_argument('--time_out', type=int, default=1500, help='timeout')
  parser.add_argument(
      '--test_data_path', type=str, default='', help='test data path')
  parser.add_argument('--verbose', action='store_true', default=False)
  args = parser.parse_args()

  logging.info('saved_model_dir: %s' % args.saved_model_dir)
  # use gfile-based file_exists (consistent with the rest of this tool) so
  # non-local filesystems supported by tf.gfile (oss/hdfs) also work
  if not file_exists(os.path.join(args.saved_model_dir, 'saved_model.pb')):
    logging.error('saved_model.pb does not exist in %s' % args.saved_model_dir)
    sys.exit(1)
  logging.info('output_dir: %s' % args.output_dir)
  logging.info('redis_url: %s' % args.redis_url)

  # custom kv lookup op library shipped with easy_rec
  lookup_lib_path = os.path.join(easy_rec.ops_dir, 'libkv_lookup.so')
  logging.info('lookup_lib_path: %s' % lookup_lib_path)

  if not file_exists(args.output_dir):
    recursive_create_dir(args.output_dir)

  # rewrite the meta graph: replace native embedding_lookup ops with
  # redis-backed kv lookup ops
  meta_graph_editor = MetaGraphEditor(
      lookup_lib_path,
      args.saved_model_dir,
      args.redis_url,
      args.redis_passwd,
      args.time_out,
      meta_graph_def=None,
      debug_dir=args.output_dir if args.verbose else '')
  meta_graph_editor.edit_graph()

  # the version is the redis key prefix; fall back to the export timestamp
  # (last non-empty component of saved_model_dir) when it is not set
  meta_graph_version = meta_graph_editor.meta_graph_version
  if meta_graph_version == '':
    export_ts = [x for x in args.saved_model_dir.split('/') if x != '']
    meta_graph_version = export_ts[-1]

  # import the edited graph into a fresh default graph
  tf.reset_default_graph()
  saver = tf.train.import_meta_graph(meta_graph_editor._meta_graph_def)

  # persist the embedding name => id mapping and attach it as a model asset
  embed_name_to_id_file = os.path.join(args.output_dir, 'embed_name_to_ids.txt')
  with GFile(embed_name_to_id_file, 'w') as fout:
    for tmp_norm_name, tmp_id in meta_graph_editor._embed_name_to_ids.items():
      fout.write('%s\t%s\n' % (tmp_norm_name, tmp_id))
  tf.add_to_collection(
      tf.GraphKeys.ASSET_FILEPATHS,
      tf.constant(
          embed_name_to_id_file, dtype=tf.string, name='embed_name_to_ids.txt'))

  # resolve signature tensors in the freshly imported graph
  graph = tf.get_default_graph()
  inputs_map = {}
  for name, tensor in meta_graph_editor.signature_def.inputs.items():
    logging.info('model inputs: %s => %s' % (name, tensor.name))
    inputs_map[name] = graph.get_tensor_by_name(tensor.name)
  outputs_map = {}
  for name, tensor in meta_graph_editor.signature_def.outputs.items():
    logging.info('model outputs: %s => %s' % (name, tensor.name))
    outputs_map[name] = graph.get_tensor_by_name(tensor.name)

  with tf.Session() as sess:
    saver.restore(sess, args.saved_model_dir + '/variables/variables')
    output_dir = os.path.join(args.output_dir, meta_graph_version)
    tf.saved_model.simple_save(
        sess, output_dir, inputs=inputs_map, outputs=outputs_map)
    # the meta_graph_version could not be passed via existing interfaces
    # so we could only write it by the raw methods
    saved_model = saved_model_pb2.SavedModel()
    with GFile(os.path.join(output_dir, 'saved_model.pb'), 'rb') as fin:
      saved_model.ParseFromString(fin.read())
    # use the resolved version (with the directory-timestamp fallback) so the
    # exported model never carries an empty version string
    saved_model.meta_graphs[
        0].meta_info_def.meta_graph_version = meta_graph_version
    with GFile(os.path.join(output_dir, 'saved_model.pb'), 'wb') as fout:
      fout.write(saved_model.SerializeToString())
    logging.info('save output to %s' % output_dir)

    if args.test_data_path:
      # sanity check: feed a small batch of samples through the edited graph
      with GFile(args.test_data_path, 'r') as fin:
        feature_vals = []
        for line_str in fin:
          line_str = line_str.strip()
          # NOTE(review): rtp fg samples are assumed to use ctrl-B ('\002')
          # as the field separator, with the label in the first field;
          # str.split('') would raise ValueError - confirm the separator
          # against the actual sample format.
          line_toks = line_str.split('\002')[1:]
          feature_vals.append('\002'.join(line_toks))
          if len(feature_vals) >= 32:
            break
      out_vals = sess.run(
          outputs_map, feed_dict={inputs_map['features']: feature_vals})
      logging.info('test_data probs:' + str(out_vals))