understanding_rl_vision/rl_clarity/loading.py (239 lines of code) (raw):

import numpy as np
import tensorflow as tf
from contextlib import contextmanager
import os
import re
import tempfile
from lucid.modelzoo.vision_base import Model
from lucid.misc.io.reading import read
from lucid.scratch.rl_util.joblib_wrapper import load_joblib, save_joblib
from .training import create_env, get_arch


def load_params(params, *, sess):
    """Load a dict of variable-name -> value into matching TF globals.

    Variables present in the graph but absent from ``params`` (and vice
    versa) are silently skipped, so partial checkpoints are tolerated.
    """
    var_list = tf.global_variables()
    for name, var_value in params.items():
        matching_vars = [var for var in var_list if var.name == name]
        if matching_vars:
            matching_vars[0].load(var_value, sess)


def save_lucid_model(config, params, *, model_path, metadata_path):
    """Export a trained policy as a lucid ``Model`` protobuf plus metadata.

    Builds the policy graph in a fresh session, loads ``params`` into it,
    freezes it to ``model_path`` and saves a metadata dict (output tensor
    names, env name, GAE hyperparameters, action combos) to
    ``metadata_path``. Returns the metadata merged with the raw model bytes.
    """
    config = config.copy()
    config.pop("num_envs")
    library = config.get("library", "baselines")
    venv = create_env(1, **config)
    arch = get_arch(**config)
    with tf.Graph().as_default(), tf.Session() as sess:
        observation_space = venv.observation_space
        observations_placeholder = tf.placeholder(
            shape=(None,) + observation_space.shape, dtype=tf.float32
        )
        if library == "baselines":
            from baselines.common.policies import build_policy

            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                # The exported model takes observations in [0, 1] (see
                # image_value_range below), while the policy was trained on
                # [0, 255] pixels — hence the * 255 rescale here.
                policy = policy_fn(
                    nbatch=None,
                    nsteps=1,
                    sess=sess,
                    observ_placeholder=(observations_placeholder * 255),
                )
            pd = policy.pd
            vf = policy.vf
        else:
            raise ValueError(f"Unsupported library: {library}")
        load_params(params, sess=sess)
        Model.save(
            model_path,
            input_name=observations_placeholder.op.name,
            output_names=[pd.logits.op.name, vf.op.name],
            image_shape=observation_space.shape,
            image_value_range=[0.0, 1.0],
        )
        metadata = {
            "policy_logits_name": pd.logits.op.name,
            "value_function_name": vf.op.name,
            "env_name": config.get("env_name"),
            "gae_gamma": config.get("gamma"),
            "gae_lambda": config.get("lambda"),
        }
    # Unwrap the env stack looking for a wrapper exposing button combos
    # (the discrete-action -> controller-button mapping, if any).
    env = venv
    while hasattr(env, "env") and (not hasattr(env, "combos")):
        env = env.env
    metadata["action_combos"] = env.combos if hasattr(env, "combos") else None
    save_joblib(metadata, metadata_path)
    return {"model_bytes": read(model_path, cache=False, mode="rb"), **metadata}


@contextmanager
def get_step_fn(config, params, *, num_envs, full_resolution):
    """Context manager yielding a ``step_fn`` that rolls the policy forward.

    Each call to the yielded ``step_fn`` returns a dict with the
    pre-step observation ("ob"), episode-start flags ("first"), the action
    taken ("ac"), and the post-step "reward" / "info"; when
    ``full_resolution`` is set it also carries "ob_full", the raw RGB frames
    taken from each env's info dict. The TF session stays open only while
    the context is active.
    """
    config = config.copy()
    config.pop("num_envs")
    library = config.get("library", "baselines")
    venv = create_env(num_envs, **config)
    arch = get_arch(**config)
    with tf.Graph().as_default(), tf.Session() as sess:
        if library == "baselines":
            from baselines.common.policies import build_policy

            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                policy = policy_fn(nbatch=venv.num_envs, nsteps=1, sess=sess)
            # Mutable rollout state shared with step_fn via closure.
            stepdata = {
                "ob": venv.reset(),
                "state": policy.initial_state,
                "first": np.ones((venv.num_envs,), bool),
            }
            if full_resolution:
                stepdata["ob_full"] = np.stack(
                    [info["rgb"] for info in venv.env.get_info()], axis=0
                )

            def step_fn():
                # Snapshot pre-step state, act, then advance the envs.
                result = {
                    "ob": stepdata["ob"],
                    "first": stepdata["first"].astype(bool),
                }
                if full_resolution:
                    result["ob_full"] = stepdata["ob_full"]
                result["ac"], _, stepdata["state"], _ = policy.step(
                    stepdata["ob"],
                    S=stepdata["state"],
                    M=stepdata["first"].astype(float),
                )
                (
                    stepdata["ob"],
                    result["reward"],
                    stepdata["first"],
                    result["info"],
                ) = venv.step(result["ac"])
                if full_resolution:
                    stepdata["ob_full"] = np.stack(
                        [info["rgb"] for info in result["info"]], axis=0
                    )
                return result

        else:
            raise ValueError(f"Unsupported library: {library}")
        load_params(params, sess=sess)
        yield step_fn


def save_observations(
    config, params, *, observations_path, num_envs, num_obs, obs_every, full_resolution
):
    """Collect ``num_obs`` batches of observations, ``obs_every`` steps apart.

    Saves {"observations": ..., ["observations_full": ...]} to
    ``observations_path`` via joblib and returns the same dict. Each saved
    batch contains one observation per env, concatenated along axis 0.
    """
    with get_step_fn(
        config, params, num_envs=num_envs, full_resolution=full_resolution
    ) as step_fn:
        observations = []
        if full_resolution:
            observations_full = []
        for _ in range(num_obs):
            # Advance obs_every steps, keeping only the final one, so that
            # successive saved observations are decorrelated.
            for _ in range(obs_every):
                step_result = step_fn()
            observations.append(step_result["ob"])
            if full_resolution:
                observations_full.append(step_result["ob_full"])
        observations = np.concatenate(observations, axis=0)
        if full_resolution:
            observations_full = np.concatenate(observations_full, axis=0)
        result = {"observations": observations}
        if full_resolution:
            result["observations_full"] = observations_full
        save_joblib(result, observations_path)
        return result


def save_trajectories(
    config, params, *, trajectories_path, num_envs, num_steps, full_resolution
):
    """Roll out ``num_steps`` consecutive steps and save stacked trajectories.

    Arrays are stacked with time on axis 1 (env on axis 0). The result is
    saved to ``trajectories_path`` and returned as {"trajectories": result}.
    """
    with get_step_fn(
        config, params, num_envs=num_envs, full_resolution=full_resolution
    ) as step_fn:
        # Discard the very first transition (warm-up step) before recording.
        step_fn()
        trajectories = [step_fn() for _ in range(num_steps)]

        def get_and_stack(ds, key, axis=1):
            return np.stack([d[key] for d in ds], axis=axis)

        result = {
            "observations": get_and_stack(trajectories, "ob"),
            "actions": get_and_stack(trajectories, "ac"),
            "rewards": get_and_stack(trajectories, "reward"),
            "firsts": get_and_stack(trajectories, "first"),
        }
        if full_resolution:
            result["observations_full"] = get_and_stack(trajectories, "ob_full")
        save_joblib(result, trajectories_path)
        return {"trajectories": result}


def load(
    checkpoint_path,
    *,
    resample=True,
    model_path=None,
    metadata_path=None,
    trajectories_path=None,
    observations_path=None,
    trajectories_kwargs=None,
    observations_kwargs=None,
    full_resolution=False,
    temp_files=False,
):
    """Load a training checkpoint and produce (or re-load) lucid artifacts.

    With ``resample=True``, rebuilds the model, observations and
    trajectories from the checkpoint, saving them next to the checkpoint
    under an "rl-clarity" subdirectory (or to temp files when
    ``temp_files=True``). With ``resample=False``, re-loads previously
    saved artifacts from the same default paths. Returns a single dict
    merging model bytes, metadata, observations and trajectories.
    """
    # NOTE: the defaults used to be mutable `{}` arguments mutated via
    # setdefault — state leaked across calls and callers' dicts were
    # modified. Copy instead so the caller's dict is left untouched.
    trajectories_kwargs = dict(trajectories_kwargs or {})
    observations_kwargs = dict(observations_kwargs or {})

    if temp_files:

        def default_path(suffix):
            # mkstemp returns (fd, path); close the fd so it doesn't leak.
            fd, path = tempfile.mkstemp(suffix=suffix)
            os.close(fd)
            return path

    else:
        # Strip the checkpoint's extension and nest outputs under a
        # sibling "rl-clarity" directory.
        path_stem = re.split(r"(?<=[^/])\.[^/\.]*$", checkpoint_path)[0]
        path_stem = os.path.join(
            os.path.dirname(path_stem), "rl-clarity", os.path.basename(path_stem)
        )

        def default_path(suffix):
            return path_stem + suffix

    if model_path is None:
        model_path = default_path(".model.pb")
    if metadata_path is None:
        metadata_path = default_path(".metadata.jd")
    if trajectories_path is None:
        trajectories_path = default_path(".trajectories.jd")
    if observations_path is None:
        observations_path = default_path(".observations.jd")

    if resample:
        trajectories_kwargs.setdefault("num_envs", 8)
        trajectories_kwargs.setdefault("num_steps", 512)
        observations_kwargs.setdefault("num_envs", 32)
        observations_kwargs.setdefault("num_obs", 128)
        observations_kwargs.setdefault("obs_every", 128)

        checkpoint_dict = load_joblib(checkpoint_path, cache=False)
        config = checkpoint_dict["args"]
        if full_resolution:
            config["render_human"] = True
        if config.get("use_lstm", 0):
            raise ValueError(
                "Recurrent networks not yet supported by this interface."
            )
        params = checkpoint_dict["params"]

        config["coinrun_old_extra_actions"] = 0
        if config.get("env_name") == "coinrun_old":
            # we may need to add extra actions depending on the size of the
            # policy head
            policy_bias_keys = [k for k in params if k.endswith("pi/b:0")]
            if policy_bias_keys:
                [policy_bias_key] = policy_bias_keys
                (num_actions,) = params[policy_bias_key].shape
                if num_actions == 9:
                    config["coinrun_old_extra_actions"] = 2

        return {
            **save_lucid_model(
                config, params, model_path=model_path, metadata_path=metadata_path
            ),
            **save_observations(
                config,
                params,
                observations_path=observations_path,
                num_envs=observations_kwargs["num_envs"],
                num_obs=observations_kwargs["num_obs"],
                obs_every=observations_kwargs["obs_every"],
                full_resolution=full_resolution,
            ),
            **save_trajectories(
                config,
                params,
                trajectories_path=trajectories_path,
                num_envs=trajectories_kwargs["num_envs"],
                num_steps=trajectories_kwargs["num_steps"],
                full_resolution=full_resolution,
            ),
        }
    else:
        observations = load_joblib(observations_path, cache=False)
        if not isinstance(observations, dict):
            # Older saves stored the raw array rather than a dict.
            observations = {"observations": observations}
        return {
            "model_bytes": read(model_path, cache=False, mode="rb"),
            **observations,
            "trajectories": load_joblib(trajectories_path, cache=False),
            **load_joblib(metadata_path, cache=False),
        }