# diagnostics/model_test/verify_model.py
import tensorflow as tf
import ngraph_bridge
from google.protobuf import text_format

# Note: set_os_env() is assumed to be a helper defined elsewhere in this
# module; it configures the environment for the selected backend.


def calculate_output(param_dict, select_device, input_example):
"""Calculate the output of the imported graph given the input.
Load the graph def from graph file on selected device, then get the tensors based on the input and output name from the graph,
then feed the input_example to the graph and retrieves the output vector.
Args:
param_dict: The dictionary contains all the user-input data in the json file.
select_device: "NGRAPH" or "CPU".
input_example: A map with key is the name of the input tensor, and value is the random generated example
Returns:
The output vector obtained from running the input_example through the graph.
"""
tf.compat.v1.reset_default_graph()
is_ckpt = False
if "pb_graph_location" in param_dict and "checkpoint_graph_location" in param_dict:
        raise Exception(
            "Only a graph file or a checkpoint file can be specified, not both!")
if "pb_graph_location" in param_dict:
pb_filename = param_dict["pb_graph_location"]
elif "checkpoint_graph_location" in param_dict:
checkpoint_filename = param_dict["checkpoint_graph_location"]
is_ckpt = True
else:
        raise Exception(
            "An input graph file OR an input checkpoint file is required!")
output_tensor_name = param_dict["output_tensor_name"]
config = tf.compat.v1.ConfigProto(
inter_op_parallelism_threads=1, allow_soft_placement=True)
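    # update_config() comes from the ngraph_bridge package; it returns a
    # ConfigProto adjusted for the nGraph bridge so that sessions created with
    # it can place supported ops on the nGraph backend.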
config_ngraph_enabled = ngraph_bridge.update_config(config)
sess = tf.compat.v1.Session(config=config_ngraph_enabled)
set_os_env(select_device)
# if checkpoint, then load checkpoint
    if is_ckpt:
        meta_filename = checkpoint_filename + '.meta'
        if not tf.io.gfile.exists(meta_filename):
            raise Exception("Meta file does not exist")
        saver = tf.compat.v1.train.import_meta_graph(meta_filename)
        if not tf.compat.v1.train.checkpoint_exists(checkpoint_filename):
            raise Exception("Checkpoint with this prefix does not exist")
        saver.restore(sess, checkpoint_filename)
        print("Model restored: " + select_device)
        graph = tf.compat.v1.get_default_graph()
    # if a graph file was given, load the graph def from the .pb / .pbtxt file
else:
graph_def = tf.compat.v1.GraphDef()
if pb_filename.endswith("pbtxt"):
with open(pb_filename, "r") as f:
text_format.Merge(f.read(), graph_def)
else:
with open(pb_filename, "rb") as f:
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
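            # NOTE: import_graph_def is called without name="", so TensorFlow
            # prefixes every imported tensor name with the default "import/"
            # scope; tensor names supplied in the JSON file must account for
            # this.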
        # use the nGraph-enabled config here as well, matching the checkpoint path
        sess = tf.compat.v1.Session(graph=graph, config=config_ngraph_enabled)
# if no outputs are specified, then compare for all tensors
if len(output_tensor_name) == 0:
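        # graph.get_operations() yields ops; each op's .outputs are tensors,
        # and the nested comprehension flattens them into one flat list of
        # tensor names.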
output_tensor_name = sum(
[[j.name for j in i.outputs] for i in graph.get_operations()], [])
# Create the tensor to its corresponding example map
tensor_to_example_map = {}
for item in input_example:
t = graph.get_tensor_by_name(item)
tensor_to_example_map[t] = input_example[item]
    tensors = []
    skipped_tensors = []
    for name in output_tensor_name:
        try:
            output_tensor = sess.run(name, feed_dict=tensor_to_example_map)
            tensors.append(output_tensor)
        except Exception:
            # tensors that cannot be evaluated with the given feed are skipped
            skipped_tensors.append(name)
return tensors, output_tensor_name, skipped_tensors
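

# Minimal usage sketch (illustrative only): the file path, tensor names, and
# input shape below are hypothetical placeholders, not values shipped with
# this tool, and the surrounding tool presumably drives calculate_output from
# its own entry point.
if __name__ == "__main__":
    import numpy as np

    example_params = {
        "pb_graph_location": "frozen_model.pb",      # hypothetical path
        "output_tensor_name": ["import/softmax:0"],  # hypothetical name
    }
    example_input = {
        "import/input:0":
            np.random.rand(1, 224, 224, 3).astype(np.float32)
    }
    outputs, names, skipped = calculate_output(example_params, "CPU",
                                               example_input)
    print("evaluated %d tensor(s), skipped %d" % (len(outputs), len(skipped)))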