in tcav/model.py [0:0]
def _try_loading_model(self, model_path):
    """Best-effort load of a TF model from ``model_path`` into ``self.sess``.

    A new v1 ``Session`` bound to a fresh ``Graph`` is assigned to
    ``self.sess``; the model is then loaded into that graph according to
    what ``model_path`` points at:

    1) A directory containing a checkpoint (ckpt.meta, ckpt.data,
       ckpt.index): the meta graph of the latest checkpoint is imported
       and its variables are restored into the session.
    2) A directory without a checkpoint: treated as a SavedModel dir
       (saved_model.pb and variables/) and loaded with the 'serve' tag.
    3) Anything else: treated as a frozen graph file. Paths ending in
       '.pb' are parsed as a binary GraphDef, every other path as a
       text-format GraphDef (e.g. .pbtxt). The graph def is then imported
       and ``self.import_prefix`` is set to True.

    When model_path is omitted, the child wrapper is responsible for
    loading the model; this method is not expected to be called then.

    Args:
      model_path: path to a checkpoint directory, a SavedModel directory,
        or a frozen graph file.

    Any ``Exception`` raised while loading is caught and reported through
    a TF warning log entry; it is not re-raised to the caller.
    """
    try:
        self.sess = tf.compat.v1.Session(graph=tf.Graph())
        with self.sess.graph.as_default():
            if not tf.io.gfile.isdir(model_path):
                # Single file: frozen graph, binary (.pb) or text format.
                graph_def = tf.compat.v1.GraphDef()
                if not model_path.endswith('.pb'):
                    tf.compat.v1.logging.info('Loading from frozen text graph.')
                    with tf.io.gfile.GFile(model_path) as graph_file:
                        graph_text = graph_file.read()
                        text_format.Parse(graph_text, graph_def)
                else:
                    tf.compat.v1.logging.info('Loading from frozen binary graph.')
                    with tf.io.gfile.GFile(model_path, 'rb') as graph_file:
                        serialized = graph_file.read()
                        graph_def.ParseFromString(serialized)
                tf.import_graph_def(graph_def)
                # NOTE(review): presumably flags that node names now carry
                # the import_graph_def name prefix -- confirm against the
                # code that reads import_prefix.
                self.import_prefix = True
            else:
                # Directory: prefer a checkpoint, else assume SavedModel.
                latest_ckpt = tf.train.latest_checkpoint(model_path)
                if not latest_ckpt:
                    tf.compat.v1.logging.info('Loading from SavedModel dir.')
                    tf.compat.v1.saved_model.loader.load(
                        self.sess, ['serve'], model_path)
                else:
                    tf.compat.v1.logging.info('Loading from the latest checkpoint.')
                    restorer = tf.compat.v1.train.import_meta_graph(
                        latest_ckpt + '.meta')
                    restorer.restore(self.sess, latest_ckpt)
    except Exception as err:
        # Deliberately non-fatal: the failure is only logged.
        message = ('An exception of type {0} occurred '
                   'when trying to load model from {1}. '
                   'Arguments:\n{2!r}')
        tf.compat.v1.logging.warn(
            message.format(type(err).__name__, model_path, err.args))