# Extracted snippet: _try_loading_model() from tcav/model.py [0:0]

  def _try_loading_model(self, model_path):
    """Attempt to load a TF model from ``model_path``.

    Three common on-disk layouts are recognized:
      1) A directory holding checkpoint files (ckpt.meta, ckpt.data,
         ckpt.index) -- restored via the latest checkpoint.
      2) A directory in SavedModel layout (saved_model.pb plus
         variables/) -- loaded with the SavedModel loader.
      3) A single frozen-graph file, binary (.pb) or text (.pbtxt).
    Which loader runs is inferred from ``model_path`` itself. When
    ``model_path`` is omitted, the child wrapper is responsible for
    loading the model instead. Any failure is logged as a warning
    rather than raised (best-effort contract).
    """
    try:
      # Fresh graph + session; everything below builds into this graph.
      self.sess = tf.compat.v1.Session(graph=tf.Graph())
      with self.sess.graph.as_default():
        if not tf.io.gfile.isdir(model_path):
          # Single file: a frozen graph, binary or text proto.
          graph_def = tf.compat.v1.GraphDef()
          if model_path.endswith('.pb'):
            tf.compat.v1.logging.info('Loading from frozen binary graph.')
            with tf.io.gfile.GFile(model_path, 'rb') as fh:
              graph_def.ParseFromString(fh.read())
          else:
            tf.compat.v1.logging.info('Loading from frozen text graph.')
            with tf.io.gfile.GFile(model_path) as fh:
              text_format.Parse(fh.read(), graph_def)
          tf.import_graph_def(graph_def)
          # Nodes were imported under the default 'import/' name scope;
          # remember that so tensor lookups can prepend the prefix.
          self.import_prefix = True
        else:
          checkpoint = tf.train.latest_checkpoint(model_path)
          if checkpoint:
            tf.compat.v1.logging.info('Loading from the latest checkpoint.')
            saver = tf.compat.v1.train.import_meta_graph(checkpoint + '.meta')
            saver.restore(self.sess, checkpoint)
          else:
            # No checkpoint found: assume the directory is a SavedModel.
            tf.compat.v1.logging.info('Loading from SavedModel dir.')
            tf.compat.v1.saved_model.loader.load(
                self.sess, ['serve'], model_path)

    except Exception as e:
      template = ('An exception of type {0} occurred '
                  'when trying to load model from {1}. '
                  'Arguments:\n{2!r}')
      tf.compat.v1.logging.warn(
          template.format(type(e).__name__, model_path, e.args))