in spinup/utils/logx.py [0:0]
def restore_tf_graph(sess, fpath):
    """
    Loads graphs saved by Logger.

    Will output a dictionary whose keys and values are from the 'inputs'
    and 'outputs' dict you specified with logger.setup_tf_saver().

    Args:
        sess: A Tensorflow session. The SavedModel is imported into
            ``sess.graph``, which does not have to be the default graph.

        fpath: Filepath to save directory. It must contain the SavedModel
            (tagged ``SERVING``) and a ``model_info.pkl`` file.

    Returns:
        A dictionary mapping from keys to tensors in the computation graph
        loaded from ``fpath``.

    Note:
        ``model_info.pkl`` is read with ``joblib.load``, which unpickles
        arbitrary objects. Only call this on directories you trust.
    """
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        fpath
    )
    # Pickle-based load; see the security note in the docstring.
    model_info = joblib.load(osp.join(fpath, 'model_info.pkl'))
    # The loader imports the graph into ``sess.graph``. Looking tensors up
    # there (rather than in tf.get_default_graph()) keeps this correct when
    # the session was created on an explicit, non-default tf.Graph().
    graph = sess.graph
    model = dict()
    # 'inputs' and 'outputs' both map user-chosen keys to tensor names.
    for section in ('inputs', 'outputs'):
        model.update({k: graph.get_tensor_by_name(v)
                      for k, v in model_info[section].items()})
    return model