# train_fn — defined in ppo_ewma/train.py

def train_fn(env_name="coinrun",
    distribution_mode="hard",
    arch="dual",  # 'shared', 'detach', or 'dual'
    # 'shared' = shared policy and value networks
    # 'dual' = separate policy and value networks
    # 'detach' = shared policy and value networks, but with the value function gradient detached during the policy phase to avoid interference
    interacts_total=100_000_000,
    num_envs=64,
    nstep=256,
    n_epoch_pi=1,
    n_epoch_vf=1,
    gamma=.999,
    lambda_=0.95,
    aux_lr=5e-4,
    aux_beta1=0.9,
    aux_beta2=0.999,
    lr=5e-4,
    beta1=0.9,
    beta2=0.999,
    nminibatch=8,
    aux_mbsize=4,
    clip_param=.2,
    kl_penalty=0.0,
    kl_ewma_decay=None,
    n_aux_epochs=6,
    n_pi=32,
    beta_clone=1.0,
    vf_true_weight=1.0,
    adv_ewma_decay=0.0,
    log_dir='/tmp/ppg',
    log_new_eps=False,
    comm=None,
    staleness=0,
    staleness_loss='decoupled',
    imp_samp_max=100.0):
    """Set up and run PPG training on a procgen-style environment.

    Configures the MPI communicator and distributed utilities, optional
    csv/stdout logging, a vectorized environment, and a ``PhasicValueModel``
    with an Impala CNN encoder, then delegates the training loop to
    ``ppg.learn``.

    Args:
        env_name: environment name passed to ``get_venv``.
        distribution_mode: environment difficulty mode passed to ``get_venv``.
        arch: network architecture — 'shared', 'detach', or 'dual'
            (see the inline comments on the parameter above).
        interacts_total: total environment interactions for training.
        num_envs: number of parallel environments in the venv.
        nstep: rollout length per policy-phase iteration.
        n_epoch_pi: PPO epochs over the policy objective per iteration.
        n_epoch_vf: PPO epochs over the value objective per iteration.
        gamma: discount factor (forwarded to PPO as ``γ``).
        lambda_: GAE lambda (forwarded to PPO as ``λ``).
        aux_lr, aux_beta1, aux_beta2: Adam hyperparameters for the
            auxiliary phase.
        lr, beta1, beta2: Adam hyperparameters for the policy phase.
        nminibatch: number of minibatches per PPO epoch.
        aux_mbsize: minibatch size during the auxiliary phase.
        clip_param: PPO clipping parameter.
        kl_penalty: coefficient on the KL penalty term.
        kl_ewma_decay: decay for the KL EWMA; None disables it.
        n_aux_epochs: number of epochs in each auxiliary phase.
        n_pi: policy-phase iterations between auxiliary phases.
        beta_clone: weight of the 'pol_distance' auxiliary loss.
        vf_true_weight: weight of the 'vf_true' auxiliary loss.
        adv_ewma_decay: decay for the advantage EWMA.
        log_dir: log directory; None skips logger configuration entirely.
        log_new_eps: whether to log newly completed episodes.
        comm: MPI communicator; defaults to ``MPI.COMM_WORLD`` when None.
        staleness: staleness setting forwarded to the PPO hyperparameters.
        staleness_loss: staleness loss mode (e.g. 'decoupled').
        imp_samp_max: clamp on the importance-sampling ratio.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    tu.setup_dist(comm=comm)
    tu.register_distributions_for_tree_util()

    if log_dir is not None:
        # Only rank 0 emits csv/stdout output; other ranks stay silent.
        format_strs = ['csv', 'stdout'] if comm.Get_rank() == 0 else []
        logger.configure(comm=comm, dir=log_dir, format_strs=format_strs)

    venv = get_venv(num_envs=num_envs, env_name=env_name, distribution_mode=distribution_mode)

    # PEP 8 (E731): use a named def rather than a lambda bound to a name.
    def enc_fn(obtype):
        # Encoder factory handed to PhasicValueModel: an Impala CNN over
        # the observation shape with a 256-dim output embedding.
        return ImpalaEncoder(
            obtype.shape,
            outsize=256,
            chans=(16, 32, 32),
        )

    model = ppg.PhasicValueModel(venv.ob_space, venv.ac_space, enc_fn, arch=arch)

    model.to(tu.dev())
    logger.log(tu.format_model(model))
    # Broadcast parameters so all ranks start from identical weights.
    tu.sync_params(model.parameters())

    # Coefficients for the named auxiliary-phase losses.
    name2coef = {"pol_distance": beta_clone, "vf_true": vf_true_weight}

    ppg.learn(
        venv=venv,
        model=model,
        interacts_total=interacts_total,
        ppo_hps=dict(
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            nstep=nstep,
            γ=gamma,
            λ=lambda_,
            nminibatch=nminibatch,
            n_epoch_vf=n_epoch_vf,
            n_epoch_pi=n_epoch_pi,
            clip_param=clip_param,
            kl_penalty=kl_penalty,
            adv_ewma_decay=adv_ewma_decay,
            log_save_opts={"save_mode": "last", "log_new_eps": log_new_eps},
            staleness=staleness,
            staleness_loss=staleness_loss,
            imp_samp_max=imp_samp_max
        ),
        aux_lr=aux_lr,
        aux_beta1=aux_beta1,
        aux_beta2=aux_beta2,
        aux_mbsize=aux_mbsize,
        n_aux_epochs=n_aux_epochs,
        n_pi=n_pi,
        kl_ewma_decay=kl_ewma_decay,
        name2coef=name2coef,
        comm=comm,
    )