in ppo_ewma/ppo.py [0:0]
def compute_advantage(model, seg, γ, λ, comm=None, adv_moments=None):
    """Fill ``seg`` with value targets and normalized advantages.

    The value head is evaluated on the segment's final observation, the
    result is appended to the stored per-step value predictions, and the
    extended sequences are handed to ``compute_gae``. The advantages it
    returns are then standardized and written back into ``seg``.

    Args:
        model: Object exposing ``v(ob, first, state)``; used only to get the
            value prediction for the final observation.
        seg: Mapping read at keys ``"finalob"``, ``"finalfirst"``,
            ``"finalstate"``, ``"reward"``, ``"vpred"`` and ``"first"``.
            Keys ``"vtarg"`` and ``"adv"`` are written in place.
            NOTE(review): the ``[:, None]`` / ``dim=1`` concatenation
            suggests a (batch, time) layout -- confirm against the caller.
        γ: Forwarded unchanged to ``compute_gae`` as ``γ``.
        λ: Forwarded unchanged to ``compute_gae`` as ``λ``.
        comm: MPI communicator; a falsy value falls back to
            ``MPI.COMM_WORLD``.
        adv_moments: Optional tracker with ``update(mean, var, count)`` and
            ``moments()``. When given, its moments replace the per-batch
            moments for normalization and are logged.
            NOTE(review): the log keys suggest an EWMA tracker -- confirm.

    Returns:
        None. Results are stored in ``seg["vtarg"]`` and ``seg["adv"]``.
    """
    if not comm:
        comm = MPI.COMM_WORLD

    # Value prediction for the observation that follows the last stored step.
    last_ob, last_first = seg["finalob"], seg["finalfirst"]
    last_value = model.v(last_ob, last_first, seg["finalstate"])

    rewards = seg["reward"]
    logger.logkv("Misc/FrameRewMean", rewards.mean())

    # Append the final-step entries along dim 1 so compute_gae receives one
    # more value / first flag than there are rewards.
    values_ext = th.cat([seg["vpred"], last_value[:, None]], dim=1)
    firsts_ext = th.cat([seg["first"], last_first[:, None]], dim=1)
    raw_adv, value_target = compute_gae(
        γ=γ, λ=λ, reward=rewards, vpred=values_ext, first=firsts_ext
    )

    log_vf_stats(comm, adv=raw_adv, vtarg=value_target, vpred=seg["vpred"])
    seg["vtarg"] = value_target

    # Batch moments gathered through the communicator.
    center, variance = tu.mpi_moments(comm, raw_adv)
    if adv_moments is not None:
        # Feed the batch statistics to the tracker, then normalize with the
        # tracker's moments instead of the raw batch ones.
        sample_count = raw_adv.numel() * comm.size
        adv_moments.update(center, variance, sample_count)
        center, variance = adv_moments.moments()
        logger.logkv_mean("VFStats/AdvEwmaMean", center)
        logger.logkv_mean("VFStats/AdvEwmaStd", math.sqrt(variance))

    # Small epsilon keeps the division finite when the variance is zero.
    seg["adv"] = (raw_adv - center) / (math.sqrt(variance) + 1e-8)