understanding_rl_vision/rl_clarity/training.py

import time
import os
import tempfile

import numpy as np
from mpi4py import MPI
import gym
from baselines.common.vec_env import (
    VecEnv,
    VecEnvWrapper,
    VecFrameStack,
    VecMonitor,
    VecNormalize,
)
from baselines.common.mpi_util import setup_mpi_gpus
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from lucid.scratch.rl_util import save_joblib

PROCGEN_ENV_NAMES = [
    "bigfish", "bossfight", "caveflyer", "chaser", "climber", "coinrun",
    "dodgeball", "fruitbot", "heist", "jumper", "leaper", "maze",
    "miner", "ninja", "plunder", "starpilot",
]

PROCGEN_KWARG_KEYS = [
    "num_levels", "start_level", "fixed_difficulty", "use_easy_jump",
    "paint_vel_info", "use_generated_assets", "use_monochrome_assets",
    "restrict_themes", "use_backgrounds", "plain_assets", "is_high_difficulty",
    "is_uniform_difficulty", "distribution_mode", "use_sequential_levels",
    "fix_background", "physics_mode", "debug_mode", "center_agent", "env_name",
    "game_type", "game_mechanics", "sample_game_mechanics", "render_human",
]

ATARI_ENV_IDS = [
    "AirRaid", "Alien", "Amidar", "Assault", "Asterix", "Asteroids",
    "Atlantis", "BankHeist", "BattleZone", "BeamRider", "Berzerk", "Bowling",
    "Boxing", "Breakout", "Carnival", "Centipede", "ChopperCommand",
    "CrazyClimber", "DemonAttack", "DoubleDunk", "ElevatorAction", "Enduro",
    "FishingDerby", "Freeway", "Frostbite", "Gopher", "Gravitar", "Hero",
    "IceHockey", "Jamesbond", "JourneyEscape", "Kangaroo", "Krull",
    "KungFuMaster", "MontezumaRevenge", "MsPacman", "NameThisGame", "Phoenix",
    "Pitfall", "Pong", "Pooyan", "PrivateEye", "Qbert", "Riverraid",
    "RoadRunner", "Robotank", "Seaquest", "Skiing", "Solaris", "SpaceInvaders",
    "StarGunner", "Tennis", "TimePilot", "Tutankham", "UpNDown", "Venture",
    "VideoPinball", "WizardOfWor", "YarsRevenge", "Zaxxon",
]

ATARI_ENV_DICT = {envid.lower(): envid for envid in ATARI_ENV_IDS}


class EpsilonGreedy(VecEnvWrapper):
    """
    Override actions with random actions with probability epsilon.

    Args:
        epsilon: the probability that each action will be overridden with a
            random action
    """

    def __init__(self, venv: VecEnv, epsilon: float):
        super().__init__(venv)
        assert isinstance(self.action_space, gym.spaces.Discrete) or isinstance(
            self.action_space, gym.spaces.MultiBinary
        )
        self.epsilon = epsilon

    def reset(self):
        return self.venv.reset()

    def step_async(self, actions):
        mask = np.random.uniform(size=self.num_envs) < self.epsilon
        new_actions = np.array(
            [
                self.action_space.sample() if mask[i] else actions[i]
                for i in range(self.num_envs)
            ]
        )
        self.venv.step_async(new_actions)

    def step_wait(self):
        return self.venv.step_wait()


class VecRewardScale(VecEnvWrapper):
    """
    Multiply the rewards of each environment in the provided VecEnv by a constant factor.

    Args:
        venv: a set of environments
        scale: the factor by which rewards are multiplied
    """

    def __init__(self, venv: VecEnv, scale: float):
        super().__init__(venv)
        self._scale = scale

    def reset(self):
        return self.venv.reset()

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        rews = rews * self._scale
        return obs, rews, dones, infos


# Our internal version of CoinRun old ended up with 2 additional actions, so
# the pre-trained models require this wrapper.
class VecExtraActions(VecEnvWrapper):
    def __init__(self, venv, *, extra_actions, default_action):
        assert isinstance(venv.action_space, gym.spaces.Discrete)
        super().__init__(
            venv, action_space=gym.spaces.Discrete(venv.action_space.n + extra_actions)
        )
        self.default_action = default_action

    def reset(self):
        return self.venv.reset()

    def step_async(self, actions):
        actions = actions.copy()
        for i in range(len(actions)):
            if actions[i] >= self.venv.action_space.n:
                actions[i] = self.default_action
        self.venv.step_async(actions)

    def step_wait(self):
        return self.venv.step_wait()


# Hack to fix a bug caused by observations being modified in-place.
class VecShallowCopy(VecEnvWrapper):
    def step_async(self, actions):
        actions = actions.copy()
        self.venv.step_async(actions)

    def reset(self):
        obs = self.venv.reset()
        return obs.copy()

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        return obs.copy(), rews.copy(), dones.copy(), infos.copy()


coinrun_initialized = False


def create_env(
    num_envs,
    *,
    env_kind="procgen",
    epsilon_greedy=0.0,
    reward_scale=1.0,
    frame_stack=1,
    use_sticky_actions=0,
    coinrun_old_extra_actions=0,
    **kwargs,
):
    """
    Create a vectorized environment of the requested kind ("procgen" or "atari"),
    with optional frame stacking, reward scaling, epsilon-greedy action noise,
    and monitoring applied on top.
    """
    if env_kind == "procgen":
        env_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        env_name = env_kwargs.pop("env_name")

        if env_name == "coinrun_old":
            import coinrun
            from coinrun.config import Config

            Config.initialize_args(use_cmd_line_args=False, **env_kwargs)
            global coinrun_initialized
            if not coinrun_initialized:
                coinrun.init_args_and_threads()
                coinrun_initialized = True
            venv = coinrun.make("standard", num_envs)
            if coinrun_old_extra_actions > 0:
                venv = VecExtraActions(
                    venv, extra_actions=coinrun_old_extra_actions, default_action=0
                )

        else:
            from procgen import ProcgenGym3Env
            import gym3

            env_kwargs = {
                k: v for k, v in env_kwargs.items() if k in PROCGEN_KWARG_KEYS
            }
            env = ProcgenGym3Env(num_envs, env_name=env_name, **env_kwargs)
            env = gym3.ExtractDictObWrapper(env, "rgb")
            venv = gym3.ToBaselinesVecEnv(env)

    elif env_kind == "atari":
        game_version = "v0" if use_sticky_actions == 1 else "v4"

        def make_atari_env(lower_env_id, num_env):
            env_id = ATARI_ENV_DICT[lower_env_id] + f"NoFrameskip-{game_version}"

            def make_atari_env_fn():
                env = make_atari(env_id)
                env = wrap_deepmind(env, frame_stack=False, clip_rewards=False)
                return env

            return SubprocVecEnv([make_atari_env_fn for i in range(num_env)])

        lower_env_id = kwargs["env_id"]
        venv = make_atari_env(lower_env_id, num_envs)

    else:
        raise ValueError(f"Unsupported env_kind: {env_kind}")

    if frame_stack > 1:
        venv = VecFrameStack(venv=venv, nstack=frame_stack)

    if reward_scale != 1:
        venv = VecRewardScale(venv, reward_scale)

    venv = VecMonitor(venv=venv, filename=None, keep_buf=100)

    if epsilon_greedy > 0:
        venv = EpsilonGreedy(venv, epsilon_greedy)

    venv = VecShallowCopy(venv)

    return venv


def get_arch(
    *,
    library="baselines",
    cnn="clear",
    use_lstm=0,
    stack_channels="16_32_32",
    emb_size=256,
    **kwargs,
):
    """
    Build the policy network constructor for Baselines: a convolutional trunk
    ("impala", "nature" or "clear"), optionally wrapped in an LSTM.
    """
    stack_channels = [int(x) for x in stack_channels.split("_")]

    if library == "baselines":
        if cnn == "impala":
            from baselines.common.models import build_impala_cnn

            conv_fn = lambda x: build_impala_cnn(
                x, depths=stack_channels, emb_size=emb_size
            )
        elif cnn == "nature":
            from baselines.common.models import nature_cnn

            conv_fn = nature_cnn
        elif cnn == "clear":
            from lucid.scratch.rl_util.arch import clear_cnn

            conv_fn = clear_cnn
        else:
            raise ValueError(f"Unsupported cnn: {cnn}")

        if use_lstm:
            from baselines.common.models import cnn_lstm

            arch = cnn_lstm(nlstm=256, conv_fn=conv_fn)
        else:
            arch = conv_fn

    else:
        raise ValueError(f"Unsupported library: {library}")

    return arch
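
# Illustrative sketch (not part of the original file): `create_env` and `get_arch`
# are typically used together to construct the training venv and the policy
# network. The argument values below are assumptions chosen for illustration,
# not the settings used for the paper.
#
#     venv = create_env(
#         16,
#         env_kind="procgen",
#         env_name="coinrun",
#         epsilon_greedy=0.05,  # override 5% of actions with random ones
#     )
#     arch = get_arch(library="baselines", cnn="clear", stack_channels="16_32_32")
#
# With env_kind="procgen", any extra kwargs are filtered against
# PROCGEN_KWARG_KEYS before being forwarded to ProcgenGym3Env.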

def create_tf_session():
    """
    Create a TensorFlow session
    """
    import tensorflow as tf

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)


def get_tf_params(scope):
    """
    Get a dictionary of parameters from TensorFlow for the specified scope
    """
    import tensorflow as tf
    from baselines.common.tf_util import get_session

    sess = get_session()
    allvars = tf.trainable_variables(scope)
    nonopt_vars = [
        v
        for v in allvars
        if all(veto not in v.name for veto in ["optimizer", "kbuf", "vbuf"])
    ]
    name2var = {v.name: v for v in nonopt_vars}
    return sess.run(name2var)


def save_data(*, save_dir, args_dict, params, step=None, extra={}):
    """
    Save the training args dict and the current model params to a checkpoint
    file under save_dir
    """
    data_dict = dict(args=args_dict, params=params, extra=extra, time=time.time())
    step_str = "" if step is None else f"-{step}"
    save_path = os.path.join(save_dir, f"checkpoint{step_str}.jd")
    if "://" not in save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_joblib(data_dict, save_path)
    return save_path


class VecClipReward(VecEnvWrapper):
    def reset(self):
        return self.venv.reset()

    def step_wait(self):
        """Bin reward to {+1, 0, -1} by its sign."""
        obs, rews, dones, infos = self.venv.step_wait()
        return obs, np.sign(rews), dones, infos
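
# Illustrative note (not part of the original file): `save_data` writes a dict
# of the form {"args": ..., "params": ..., "extra": ..., "time": ...} to
# <save_dir>/checkpoint.jd (or checkpoint-<step>.jd when a step is given), e.g.
#
#     path = save_data(
#         save_dir="/tmp/rl_clarity_example",  # hypothetical local directory
#         args_dict={"cnn": "clear"},
#         params=get_tf_params("ppo2_model"),
#     )
#
# For a local save_dir, the checkpoint should be readable with joblib.load(path),
# assuming save_joblib wraps joblib.dump.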

def train(comm=None, *, save_dir=None, **kwargs):
    """
    Train a model using Baselines' PPO2, and save checkpoint files in the
    required format.

    There is one required kwarg: either env_name (for env_kind="procgen") or
    env_id (for env_kind="atari").

    Models for the paper were trained with 16 parallel MPI workers.

    Note: this code has not been well-tested. (An illustrative usage sketch is
    given at the bottom of this file.)
    """
    kwargs.setdefault("env_kind", "procgen")
    kwargs.setdefault("num_envs", 64)
    kwargs.setdefault("learning_rate", 5e-4)
    kwargs.setdefault("entropy_coeff", 0.01)
    kwargs.setdefault("gamma", 0.999)
    kwargs.setdefault("lambda", 0.95)
    kwargs.setdefault("num_steps", 256)
    kwargs.setdefault("num_minibatches", 8)
    kwargs.setdefault("library", "baselines")
    kwargs.setdefault("save_all", False)
    kwargs.setdefault("ppo_epochs", 3)
    kwargs.setdefault("clip_range", 0.2)
    kwargs.setdefault("timesteps_per_proc", 1_000_000_000)
    kwargs.setdefault("cnn", "clear")
    kwargs.setdefault("use_lstm", 0)
    kwargs.setdefault("stack_channels", "16_32_32")
    kwargs.setdefault("emb_size", 256)
    kwargs.setdefault("epsilon_greedy", 0.0)
    kwargs.setdefault("reward_scale", 1.0)
    kwargs.setdefault("frame_stack", 1)
    kwargs.setdefault("use_sticky_actions", 0)
    kwargs.setdefault("clip_vf", 1)
    kwargs.setdefault("reward_processing", "none")
    kwargs.setdefault("save_interval", 10)

    if comm is None:
        comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    setup_mpi_gpus()

    if save_dir is None:
        save_dir = tempfile.mkdtemp(prefix="rl_clarity_train_")

    create_env_kwargs = kwargs.copy()
    num_envs = create_env_kwargs.pop("num_envs")
    venv = create_env(num_envs, **create_env_kwargs)

    library = kwargs["library"]
    if library == "baselines":
        reward_processing = kwargs["reward_processing"]
        if reward_processing == "none":
            pass
        elif reward_processing == "clip":
            venv = VecClipReward(venv=venv)
        elif reward_processing == "normalize":
            venv = VecNormalize(venv=venv, ob=False, per_env=False)
        else:
            raise ValueError(f"Unsupported reward processing: {reward_processing}")

        scope = "ppo2_model"

        def update_fn(update, params=None):
            if rank == 0:
                save_interval = kwargs["save_interval"]
                if save_interval > 0 and update % save_interval == 0:
                    print("Saving...")
                    params = get_tf_params(scope)
                    save_path = save_data(
                        save_dir=save_dir,
                        args_dict=kwargs,
                        params=params,
                        step=(update if kwargs["save_all"] else None),
                    )
                    print(f"Saved to: {save_path}")

        sess = create_tf_session()
        sess.__enter__()

        if kwargs["use_lstm"]:
            raise ValueError("Recurrent networks not yet supported.")
        arch = get_arch(**kwargs)

        from baselines.ppo2 import ppo2

        ppo2.learn(
            env=venv,
            network=arch,
            total_timesteps=kwargs["timesteps_per_proc"],
            save_interval=0,
            nsteps=kwargs["num_steps"],
            nminibatches=kwargs["num_minibatches"],
            lam=kwargs["lambda"],
            gamma=kwargs["gamma"],
            noptepochs=kwargs["ppo_epochs"],
            log_interval=1,
            ent_coef=kwargs["entropy_coeff"],
            mpi_rank_weight=1.0,
            clip_vf=bool(kwargs["clip_vf"]),
            comm=comm,
            lr=kwargs["learning_rate"],
            cliprange=kwargs["clip_range"],
            update_fn=update_fn,
            init_fn=None,
            vf_coef=0.5,
            max_grad_norm=0.5,
        )
    else:
        raise ValueError(f"Unsupported library: {library}")

    return save_dir
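
# Illustrative usage sketch (not part of the original file). The env name,
# number of environments, and timestep budget below are assumptions chosen for
# a quick local run, not the settings used for the paper. To reproduce the
# multi-worker setup, run under MPI (e.g. `mpiexec -n 16 python training.py`).
if __name__ == "__main__":
    example_save_dir = train(
        env_kind="procgen",
        env_name="coinrun",          # required kwarg for env_kind="procgen"
        num_envs=8,                  # small value for illustration
        timesteps_per_proc=100_000,  # short run; the default is 1e9
        save_interval=10,
    )
    print(f"Checkpoints written to: {example_save_dir}")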