def load_policy()

in ma_policy/load_policy.py [0:0]


def load_policy(path, env=None, scope='policy'):
    '''
        Load a saved policy and restore its variables.

        Args:
            path (string): policy path. It is handed to np.load, and the
                loaded archive must contain a pickled 'policy_fn_and_args'
                entry whose unpickled value has an 'args' key.
            env (Gym.Env): if given, its observation_space and action_space
                replace the saved 'ob_space' and 'ac_space' policy arguments
            scope (string): The base scope for the policy variables
        Returns:
            The MAPolicy built from the saved arguments, after
            load_variables has been called on it with the remaining
            entries of the archive.
    '''
    # TODO this will probably need to be changed when trying to run policy on GPU
    if tf.get_default_session() is None:
        tf_config = tf.ConfigProto(
            inter_op_parallelism_threads=1,
            intra_op_parallelism_threads=1)
        # The session is entered and deliberately never exited here, so it
        # is still active once this function returns.
        sess = tf.Session(config=tf_config)
        sess.__enter__()

    # np.load keeps the archive's file handle open; copy every entry into a
    # plain dict and close the archive straight away instead of leaking it.
    with np.load(path) as archive:
        policy_dict = dict(archive)

    # SECURITY: unpickling executes arbitrary code embedded in the file.
    # Only load policy files that come from a trusted source.
    # pop() also removes the entry, so only the remaining entries are
    # passed on to load_variables below.
    policy_fn_and_args_raw = pickle.loads(
        policy_dict.pop('policy_fn_and_args'))
    policy_args = policy_fn_and_args_raw['args']
    policy_args['scope'] = scope

    if env is not None:
        policy_args['ob_space'] = env.observation_space
        policy_args['ac_space'] = env.action_space

    policy = MAPolicy(**policy_args)

    load_variables(policy, policy_dict)
    return policy