# understanding_rl_vision/rl_clarity/training.py
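# NOTE: the imports for this excerpt are assumptions reconstructed from the
# calls below. `tempfile`, `mpi4py`, and baselines' `setup_mpi_gpus` are clear
# from usage; the package-local helpers (create_env, get_arch, get_tf_params,
# save_data, create_tf_session, VecClipReward, VecNormalize) are presumed to
# be imported from elsewhere in this package.
import tempfile

from mpi4py import MPI
from baselines.common.mpi_util import setup_mpi_gpus
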
def train(comm=None, *, save_dir=None, **kwargs):
"""
Train a model using Baselines' PPO2, and to save a checkpoint file in the
required format.
There is one required kwarg: either env_name (for env_kind="procgen") or
env_id (for env_kind="atari").
Models for the paper were trained with 16 parallel MPI workers.
Note: this code has not been well-tested.
"""
kwargs.setdefault("env_kind", "procgen")
kwargs.setdefault("num_envs", 64)
kwargs.setdefault("learning_rate", 5e-4)
kwargs.setdefault("entropy_coeff", 0.01)
kwargs.setdefault("gamma", 0.999)
kwargs.setdefault("lambda", 0.95)
kwargs.setdefault("num_steps", 256)
kwargs.setdefault("num_minibatches", 8)
kwargs.setdefault("library", "baselines")
kwargs.setdefault("save_all", False)
kwargs.setdefault("ppo_epochs", 3)
kwargs.setdefault("clip_range", 0.2)
kwargs.setdefault("timesteps_per_proc", 1_000_000_000)
kwargs.setdefault("cnn", "clear")
kwargs.setdefault("use_lstm", 0)
kwargs.setdefault("stack_channels", "16_32_32")
kwargs.setdefault("emb_size", 256)
kwargs.setdefault("epsilon_greedy", 0.0)
kwargs.setdefault("reward_scale", 1.0)
kwargs.setdefault("frame_stack", 1)
kwargs.setdefault("use_sticky_actions", 0)
kwargs.setdefault("clip_vf", 1)
kwargs.setdefault("reward_processing", "none")
kwargs.setdefault("save_interval", 10)
    if comm is None:
        comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    setup_mpi_gpus()
    if save_dir is None:
        save_dir = tempfile.mkdtemp(prefix="rl_clarity_train_")
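    # Build the vectorized environment. num_envs is popped here so it is
    # passed positionally rather than forwarded inside the kwargs.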
    create_env_kwargs = kwargs.copy()
    num_envs = create_env_kwargs.pop("num_envs")
    venv = create_env(num_envs, **create_env_kwargs)
    library = kwargs["library"]
    if library == "baselines":
        reward_processing = kwargs["reward_processing"]
        if reward_processing == "none":
            pass
        elif reward_processing == "clip":
            venv = VecClipReward(venv=venv)
        elif reward_processing == "normalize":
            venv = VecNormalize(venv=venv, ob=False, per_env=False)
        else:
            raise ValueError(f"Unsupported reward processing: {reward_processing}")

        scope = "ppo2_model"
        def update_fn(update, params=None):
            if rank == 0:
                save_interval = kwargs["save_interval"]
                if save_interval > 0 and update % save_interval == 0:
                    print("Saving...")
                    params = get_tf_params(scope)
                    save_path = save_data(
                        save_dir=save_dir,
                        args_dict=kwargs,
                        params=params,
                        step=(update if kwargs["save_all"] else None),
                    )
                    print(f"Saved to: {save_path}")
        sess = create_tf_session()
        sess.__enter__()

        if kwargs["use_lstm"]:
            raise ValueError("Recurrent networks not yet supported.")
        arch = get_arch(**kwargs)

        from baselines.ppo2 import ppo2

        ppo2.learn(
            env=venv,
            network=arch,
            total_timesteps=kwargs["timesteps_per_proc"],
            save_interval=0,
            nsteps=kwargs["num_steps"],
            nminibatches=kwargs["num_minibatches"],
            lam=kwargs["lambda"],
            gamma=kwargs["gamma"],
            noptepochs=kwargs["ppo_epochs"],
            log_interval=1,
            ent_coef=kwargs["entropy_coeff"],
            mpi_rank_weight=1.0,
            clip_vf=bool(kwargs["clip_vf"]),
            comm=comm,
            lr=kwargs["learning_rate"],
            cliprange=kwargs["clip_range"],
            update_fn=update_fn,
            init_fn=None,
            vf_coef=0.5,
            max_grad_norm=0.5,
        )
    else:
        raise ValueError(f"Unsupported library: {library}")

    return save_dir
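

# Minimal usage sketch (not part of the original module): train a Procgen
# model with the defaults above. The "coinrun" environment name is
# illustrative only. Models for the paper were trained with 16 parallel MPI
# workers, e.g. via something like
# `mpiexec -n 16 python -m understanding_rl_vision.rl_clarity.training`.
if __name__ == "__main__":
    save_dir = train(env_name="coinrun")
    print(f"Checkpoints written to: {save_dir}")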