def compute_advantage()

in ppo_ewma/ppo.py [0:0]


def compute_advantage(model, seg, γ, λ, comm=None, adv_moments=None):
    """Compute GAE advantages and value targets for a rollout segment, in place.

    Bootstraps from the value estimate of the final observation, runs
    ``compute_gae``, then stores the value targets in ``seg["vtarg"]`` and the
    normalized advantages in ``seg["adv"]``. Nothing is returned.

    Args:
        model: object exposing ``v(ob, first, state)``; only used here to get
            the value estimate of the observation after the last step.
        seg: rollout dict. Keys read: "finalob", "finalfirst", "finalstate",
            "reward", "vpred", "first". Keys written: "vtarg", "adv".
        γ: forwarded unchanged to ``compute_gae``.
        λ: forwarded unchanged to ``compute_gae``.
        comm: MPI communicator passed to the stats helpers. Defaults to
            ``MPI.COMM_WORLD`` when None.
        adv_moments: optional running-moments tracker with
            ``update(mean, var, count)`` and ``moments()``. When given, it is
            updated with this batch's moments and its returned moments are
            used for normalization instead of the batch moments.
    """
    # Explicit None check: an MPI communicator can be falsy (e.g. COMM_NULL),
    # and such a value must not be silently swapped for COMM_WORLD.
    if comm is None:
        comm = MPI.COMM_WORLD
    finalob, finalfirst = seg["finalob"], seg["finalfirst"]
    # Value of the observation following the last stored step; it is the
    # bootstrap value for the final GAE step.
    vpredfinal = model.v(finalob, finalfirst, seg["finalstate"])
    reward = seg["reward"]
    logger.logkv("Misc/FrameRewMean", reward.mean())
    # The bootstrap value / first-flag is appended as one extra step along
    # dim 1 (assumes seg tensors are laid out (nenv, nstep) -- TODO confirm).
    adv, vtarg = compute_gae(
        γ=γ,
        λ=λ,
        reward=reward,
        vpred=th.cat([seg["vpred"], vpredfinal[:, None]], dim=1),
        first=th.cat([seg["first"], finalfirst[:, None]], dim=1),
    )
    log_vf_stats(comm, adv=adv, vtarg=vtarg, vpred=seg["vpred"])
    seg["vtarg"] = vtarg
    # Batch advantage moments (presumably reduced across ranks by mpi_moments).
    adv_mean, adv_var = tu.mpi_moments(comm, adv)
    if adv_moments is not None:
        # Global sample count. NOTE(review): assumes every rank holds the same
        # number of advantages; confirm if per-rank segment sizes can differ.
        adv_moments.update(adv_mean, adv_var, adv.numel() * comm.size)
        adv_mean, adv_var = adv_moments.moments()
        logger.logkv_mean("VFStats/AdvEwmaMean", adv_mean)
        logger.logkv_mean("VFStats/AdvEwmaStd", math.sqrt(adv_var))
    # The epsilon keeps the division finite when the variance is zero.
    seg["adv"] = (adv - adv_mean) / (math.sqrt(adv_var) + 1e-8)