in ma_policy/load_policy.py [0:0]
def load_policy(path, env=None, scope='policy'):
    '''
    Load a policy.
    Args:
        path (string): path to the saved policy archive (loaded with
            np.load; must contain a 'policy_fn_and_args' entry)
        env (Gym.Env): if given, its observation_space and action_space
            replace the ob_space/ac_space stored with the policy
        scope (string): The base scope for the policy variables
    Returns:
        MAPolicy built from the stored constructor args, with its
        variables restored from the remaining archive entries.
    '''
    # TODO this will probably need to be changed when trying to run policy on GPU
    if tf.get_default_session() is None:
        tf_config = tf.ConfigProto(
            inter_op_parallelism_threads=1,
            intra_op_parallelism_threads=1)
        sess = tf.Session(config=tf_config)
        # Deliberately entered and never exited so that this session becomes
        # the default session for subsequent graph work in this process.
        sess.__enter__()

    # np.load on an .npz archive returns an NpzFile that keeps the underlying
    # file open; copy every entry out and close the file deterministically.
    with np.load(path) as archive:
        policy_dict = dict(archive)

    # Remove the constructor spec so that only variable values remain for
    # load_variables below.
    # NOTE(review): pickle.loads can execute arbitrary code -- only load
    # policy files from trusted sources.
    policy_fn_and_args_raw = pickle.loads(policy_dict.pop('policy_fn_and_args'))
    policy_args = policy_fn_and_args_raw['args']
    policy_args['scope'] = scope
    if env is not None:
        # Override the spaces recorded at save time with the live env's.
        policy_args['ob_space'] = env.observation_space
        policy_args['ac_space'] = env.action_space
    policy = MAPolicy(**policy_args)
    load_variables(policy, policy_dict)
    return policy