ma_policy/load_policy.py (63 lines of code) (raw):

import os import numpy as np import tensorflow as tf import logging import sys import traceback import cloudpickle as pickle from ma_policy.ma_policy import MAPolicy def shape_list(x): ''' deal with dynamic shape in tensorflow cleanly ''' ps = x.get_shape().as_list() ts = tf.shape(x) return [ts[i] if ps[i] is None else ps[i] for i in range(len(ps))] def replace_base_scope(var_name, new_base_scope): split = var_name.split('/') split[0] = new_base_scope return os.path.normpath('/'.join(split)) def load_variables(policy, weights): weights = {os.path.normpath(key): value for key, value in weights.items()} weights = {replace_base_scope(key, policy.scope): value for key, value in weights.items()} assign_ops = [] for var in policy.get_variables(): var_name = os.path.normpath(var.name) if var_name not in weights: logging.warning(f"{var_name} was not found in weights dict. This will be reinitialized.") tf.get_default_session().run(var.initializer) else: try: assert np.all(np.array(shape_list(var)) == np.array(weights[var_name].shape)) assign_ops.append(var.assign(weights[var_name])) except Exception: traceback.print_exc(file=sys.stdout) print(f"Error assigning weights of shape {weights[var_name].shape} to {var}") sys.exit() tf.get_default_session().run(assign_ops) def load_policy(path, env=None, scope='policy'): ''' Load a policy. Args: path (string): policy path env (Gym.Env): This will update the observation space of the policy that is returned scope (string): The base scope for the policy variables ''' # TODO this will probably need to be changed when trying to run policy on GPU if tf.get_default_session() is None: tf_config = tf.ConfigProto( inter_op_parallelism_threads=1, intra_op_parallelism_threads=1) sess = tf.Session(config=tf_config) sess.__enter__() policy_dict = dict(np.load(path)) policy_fn_and_args_raw = pickle.loads(policy_dict['policy_fn_and_args']) policy_args = policy_fn_and_args_raw['args'] policy_args['scope'] = scope if env is not None: policy_args['ob_space'] = env.observation_space policy_args['ac_space'] = env.action_space policy = MAPolicy(**policy_args) del policy_dict['policy_fn_and_args'] load_variables(policy, policy_dict) return policy