ppo_ewma/train.py (113 lines of code) (raw):

import argparse

from mpi4py import MPI

from . import ppg
from . import torch_util as tu
from .impala_cnn import ImpalaEncoder
from . import logger
from .envs import get_venv


def train_fn(
    env_name="coinrun",
    distribution_mode="hard",
    arch="dual",  # 'shared', 'detach', or 'dual'
    # 'shared' = shared policy and value networks
    # 'dual'   = separate policy and value networks
    # 'detach' = shared policy and value networks, but with the value function
    #            gradient detached during the policy phase to avoid interference
    interacts_total=100_000_000,
    num_envs=64,
    nstep=256,
    n_epoch_pi=1,
    n_epoch_vf=1,
    gamma=.999,
    lambda_=0.95,
    aux_lr=5e-4,
    aux_beta1=0.9,
    aux_beta2=0.999,
    lr=5e-4,
    beta1=0.9,
    beta2=0.999,
    nminibatch=8,
    aux_mbsize=4,
    clip_param=.2,
    kl_penalty=0.0,
    kl_ewma_decay=None,
    n_aux_epochs=6,
    n_pi=32,
    beta_clone=1.0,
    vf_true_weight=1.0,
    adv_ewma_decay=0.0,
    log_dir='/tmp/ppg',
    log_new_eps=False,
    comm=None,
    staleness=0,
    staleness_loss='decoupled',
    imp_samp_max=100.0,
):
    # Distributed setup: every rank joins the MPI communicator and registers
    # the distributions used by the tree utilities.
    if comm is None:
        comm = MPI.COMM_WORLD
    tu.setup_dist(comm=comm)
    tu.register_distributions_for_tree_util()

    # Only rank 0 writes csv/stdout logs.
    if log_dir is not None:
        format_strs = ['csv', 'stdout'] if comm.Get_rank() == 0 else []
        logger.configure(comm=comm, dir=log_dir, format_strs=format_strs)

    # Vectorized environment and the Impala-CNN encoder backing the phasic model.
    venv = get_venv(num_envs=num_envs, env_name=env_name, distribution_mode=distribution_mode)

    enc_fn = lambda obtype: ImpalaEncoder(
        obtype.shape,
        outsize=256,
        chans=(16, 32, 32),
    )
    model = ppg.PhasicValueModel(venv.ob_space, venv.ac_space, enc_fn, arch=arch)

    model.to(tu.dev())
    logger.log(tu.format_model(model))
    tu.sync_params(model.parameters())  # keep parameters identical across ranks

    # Auxiliary-phase loss coefficients: policy-distance (clone) and true value loss.
    name2coef = {"pol_distance": beta_clone, "vf_true": vf_true_weight}

    ppg.learn(
        venv=venv,
        model=model,
        interacts_total=interacts_total,
        ppo_hps=dict(
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            nstep=nstep,
            γ=gamma,
            λ=lambda_,
            nminibatch=nminibatch,
            n_epoch_vf=n_epoch_vf,
            n_epoch_pi=n_epoch_pi,
            clip_param=clip_param,
            kl_penalty=kl_penalty,
            adv_ewma_decay=adv_ewma_decay,
            log_save_opts={"save_mode": "last", "log_new_eps": log_new_eps},
            staleness=staleness,
            staleness_loss=staleness_loss,
            imp_samp_max=imp_samp_max,
        ),
        aux_lr=aux_lr,
        aux_beta1=aux_beta1,
        aux_beta2=aux_beta2,
        aux_mbsize=aux_mbsize,
        n_aux_epochs=n_aux_epochs,
        n_pi=n_pi,
        kl_ewma_decay=kl_ewma_decay,
        name2coef=name2coef,
        comm=comm,
    )


def main():
    parser = argparse.ArgumentParser(description='Process PPG training arguments.')
    parser.add_argument('--env_name', type=str, default='coinrun')
    parser.add_argument('--num_envs', type=int, default=64)
    parser.add_argument('--n_epoch_pi', type=int, default=1)
    parser.add_argument('--n_epoch_vf', type=int, default=1)
    parser.add_argument('--n_aux_epochs', type=int, default=6)
    parser.add_argument('--n_pi', type=int, default=32)
    parser.add_argument('--clip_param', type=float, default=0.2)
    parser.add_argument('--kl_penalty', type=float, default=0.0)
    parser.add_argument('--arch', type=str, default='dual')  # 'shared', 'detach', or 'dual'
    args = parser.parse_args()

    comm = MPI.COMM_WORLD

    train_fn(
        env_name=args.env_name,
        num_envs=args.num_envs,
        n_epoch_pi=args.n_epoch_pi,
        n_epoch_vf=args.n_epoch_vf,
        n_aux_epochs=args.n_aux_epochs,
        n_pi=args.n_pi,
        arch=args.arch,
        comm=comm,
    )


if __name__ == '__main__':
    main()
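
Usage sketch (not part of the original file): the module exposes both a CLI entry point via main() and the programmatic train_fn(). The invocation below is an assumption based on the relative imports and the ppo_ewma package path; it presumes the package is importable and its MPI/environment dependencies are installed. The environment name and interaction budget are illustrative, not defaults from the file.

# Run the CLI under MPI; the file must be executed as a module so that
# `from . import ...` resolves (command assumed, not taken from the source):
#   mpiexec -np 4 python -m ppo_ewma.train --env_name coinrun --arch dual

# Or call train_fn directly from Python:
from mpi4py import MPI
from ppo_ewma.train import train_fn

train_fn(
    env_name="starpilot",       # any environment name accepted by get_venv (assumed to be a Procgen game)
    interacts_total=1_000_000,  # short illustrative budget; the file's default is 100M
    arch="detach",              # 'shared', 'detach', or 'dual', as documented above
    comm=MPI.COMM_WORLD,
)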