# ppo_ewma/train.py
import argparse
from mpi4py import MPI
from . import ppg
from . import torch_util as tu
from .impala_cnn import ImpalaEncoder
from . import logger
from .envs import get_venv
def train_fn(env_name="coinrun",
             distribution_mode="hard",
             arch="dual",  # 'shared', 'detach', or 'dual'
             # 'shared' = shared policy and value networks
             # 'dual' = separate policy and value networks
             # 'detach' = shared policy and value networks, but with the value function gradient detached during the policy phase to avoid interference
             interacts_total=100_000_000,
             num_envs=64,
             nstep=256,
             n_epoch_pi=1,
             n_epoch_vf=1,
             gamma=.999,
             lambda_=0.95,
             aux_lr=5e-4,
             aux_beta1=0.9,
             aux_beta2=0.999,
             lr=5e-4,
             beta1=0.9,
             beta2=0.999,
             nminibatch=8,
             aux_mbsize=4,
             clip_param=.2,
             kl_penalty=0.0,
             kl_ewma_decay=None,
             n_aux_epochs=6,
             n_pi=32,
             beta_clone=1.0,
             vf_true_weight=1.0,
             adv_ewma_decay=0.0,
             log_dir='/tmp/ppg',
             log_new_eps=False,
             comm=None,
             staleness=0,
             staleness_loss='decoupled',
             imp_samp_max=100.0):
    """Set up distributed state, build the env and model, and run PPG training.

    Every keyword argument is a hyperparameter that is forwarded to
    ``ppg.learn`` — either directly or bundled inside the ``ppo_hps`` dict.
    If ``comm`` is None, MPI.COMM_WORLD is used.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    tu.setup_dist(comm=comm)
    tu.register_distributions_for_tree_util()

    if log_dir is not None:
        # Only rank 0 emits csv/stdout output; the other ranks stay silent.
        fmt_strs = ['csv', 'stdout'] if comm.Get_rank() == 0 else []
        logger.configure(comm=comm, dir=log_dir, format_strs=fmt_strs)

    venv = get_venv(num_envs=num_envs, env_name=env_name, distribution_mode=distribution_mode)

    def make_encoder(obtype):
        # IMPALA CNN trunk over the observation space.
        return ImpalaEncoder(
            obtype.shape,
            outsize=256,
            chans=(16, 32, 32),
        )

    model = ppg.PhasicValueModel(venv.ob_space, venv.ac_space, make_encoder, arch=arch)
    model.to(tu.dev())
    logger.log(tu.format_model(model))
    # Broadcast initial parameters so all ranks start from identical weights.
    tu.sync_params(model.parameters())

    # Coefficients for the named losses used during the auxiliary phase.
    aux_loss_coefs = {"pol_distance": beta_clone, "vf_true": vf_true_weight}

    # Hyperparameters consumed by the PPO (policy) phase.
    ppo_hyperparams = dict(
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        nstep=nstep,
        γ=gamma,
        λ=lambda_,
        nminibatch=nminibatch,
        n_epoch_vf=n_epoch_vf,
        n_epoch_pi=n_epoch_pi,
        clip_param=clip_param,
        kl_penalty=kl_penalty,
        adv_ewma_decay=adv_ewma_decay,
        log_save_opts={"save_mode": "last", "log_new_eps": log_new_eps},
        staleness=staleness,
        staleness_loss=staleness_loss,
        imp_samp_max=imp_samp_max,
    )

    ppg.learn(
        venv=venv,
        model=model,
        interacts_total=interacts_total,
        ppo_hps=ppo_hyperparams,
        aux_lr=aux_lr,
        aux_beta1=aux_beta1,
        aux_beta2=aux_beta2,
        aux_mbsize=aux_mbsize,
        n_aux_epochs=n_aux_epochs,
        n_pi=n_pi,
        kl_ewma_decay=kl_ewma_decay,
        name2coef=aux_loss_coefs,
        comm=comm,
    )
def main():
    """Parse command-line hyperparameters and launch PPG training on MPI.COMM_WORLD.

    Only a subset of train_fn's hyperparameters is exposed on the CLI; the
    rest keep their defaults.
    """
    parser = argparse.ArgumentParser(description='Process PPG training arguments.')
    parser.add_argument('--env_name', type=str, default='coinrun')
    parser.add_argument('--num_envs', type=int, default=64)
    parser.add_argument('--n_epoch_pi', type=int, default=1)
    parser.add_argument('--n_epoch_vf', type=int, default=1)
    parser.add_argument('--n_aux_epochs', type=int, default=6)
    parser.add_argument('--n_pi', type=int, default=32)
    parser.add_argument('--clip_param', type=float, default=0.2)
    parser.add_argument('--kl_penalty', type=float, default=0.0)
    parser.add_argument('--arch', type=str, default='dual')  # 'shared', 'detach', or 'dual'
    args = parser.parse_args()
    comm = MPI.COMM_WORLD
    train_fn(
        env_name=args.env_name,
        num_envs=args.num_envs,
        n_epoch_pi=args.n_epoch_pi,
        n_epoch_vf=args.n_epoch_vf,
        n_aux_epochs=args.n_aux_epochs,
        n_pi=args.n_pi,
        # Bug fix: --clip_param and --kl_penalty were parsed but never
        # forwarded, so CLI values were silently ignored in favor of
        # train_fn's defaults.
        clip_param=args.clip_param,
        kl_penalty=args.kl_penalty,
        arch=args.arch,
        comm=comm)
# Script entry point: run training when executed directly (not on import).
if __name__ == '__main__':
    main()