ppo_ewma/ppo.py
"""
Mostly copied from ppo.py but with some extra options added that are relevant to phasic
"""
import numpy as np
import torch as th
from queue import Queue
from mpi4py import MPI
from functools import partial
from .tree_util import tree_map, tree_multimap
from . import torch_util as tu
from .log_save_helper import LogSaveHelper
from .minibatch_optimize import minibatch_optimize
from .roller import Roller
from .reward_normalizer import RewardNormalizer
import math
from . import logger
INPUT_KEYS = {"ob", "ac", "first", "logp", "rec_logp", "vtarg", "adv", "state_in"}
def compute_gae(
*,
vpred: "(th.Tensor[1, float]) value predictions",
reward: "(th.Tensor[1, float]) rewards",
first: "(th.Tensor[1, bool]) mark beginning of episodes",
γ: "(float)",
λ: "(float)"
):
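    """
    Compute GAE advantages and the corresponding value targets.

    vpred and first carry one extra timestep relative to reward (the bootstrap value and
    first-flag for the state after the last transition). The recursion is
        delta_t = reward_t + γ * notlast * vpred_{t+1} - vpred_t
        adv_t   = delta_t + γ * λ * notlast * adv_{t+1}
    and the value targets are vtarg = vpred[:, :-1] + adv.
    """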
orig_device = vpred.device
assert orig_device == reward.device == first.device
vpred, reward, first = (x.cpu() for x in (vpred, reward, first))
first = first.to(dtype=th.float32)
assert first.dim() == 2
nenv, nstep = reward.shape
assert vpred.shape == first.shape == (nenv, nstep + 1)
adv = th.zeros(nenv, nstep, dtype=th.float32)
lastgaelam = 0
for t in reversed(range(nstep)):
notlast = 1.0 - first[:, t + 1]
nextvalue = vpred[:, t + 1]
# notlast: whether next timestep is from the same episode
delta = reward[:, t] + notlast * γ * nextvalue - vpred[:, t]
adv[:, t] = lastgaelam = delta + notlast * γ * λ * lastgaelam
vtarg = vpred[:, :-1] + adv
return adv.to(device=orig_device), vtarg.to(device=orig_device)
def log_vf_stats(comm, **kwargs):
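    """Log the value function's explained variance and the mean/std of vpred, vtarg, and adv."""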
logger.logkv(
"VFStats/EV", tu.explained_variance(kwargs["vpred"], kwargs["vtarg"], comm)
)
for key in ["vpred", "vtarg", "adv"]:
logger.logkv_mean(f"VFStats/{key.capitalize()}Mean", kwargs[key].mean())
logger.logkv_mean(f"VFStats/{key.capitalize()}Std", kwargs[key].std())
def compute_advantage(model, seg, γ, λ, comm=None, adv_moments=None):
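    """
    Compute GAE advantages and value targets for a rollout segment, storing them in
    seg["adv"] and seg["vtarg"]. Advantages are normalized using EWMA moments when
    adv_moments is provided, otherwise using the current batch's MPI moments.
    """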
comm = comm or MPI.COMM_WORLD
finalob, finalfirst = seg["finalob"], seg["finalfirst"]
vpredfinal = model.v(finalob, finalfirst, seg["finalstate"])
reward = seg["reward"]
logger.logkv("Misc/FrameRewMean", reward.mean())
adv, vtarg = compute_gae(
γ=γ,
λ=λ,
reward=reward,
vpred=th.cat([seg["vpred"], vpredfinal[:, None]], dim=1),
first=th.cat([seg["first"], finalfirst[:, None]], dim=1),
)
log_vf_stats(comm, adv=adv, vtarg=vtarg, vpred=seg["vpred"])
seg["vtarg"] = vtarg
adv_mean, adv_var = tu.mpi_moments(comm, adv)
if adv_moments is not None:
adv_moments.update(adv_mean, adv_var, adv.numel() * comm.size)
adv_mean, adv_var = adv_moments.moments()
logger.logkv_mean("VFStats/AdvEwmaMean", adv_mean)
logger.logkv_mean("VFStats/AdvEwmaStd", math.sqrt(adv_var))
seg["adv"] = (adv - adv_mean) / (math.sqrt(adv_var) + 1e-8)
def tree_cat(trees):
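    """Concatenate a sequence of pytrees of tensors along the batch (first) dimension."""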
return tree_multimap(lambda *xs: th.cat(xs, dim=0), *trees)
def recompute_logp(*, model, seg, mbsize):
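    """
    Recompute the log-probabilities of the stored actions under the current model,
    in minibatches of size mbsize, and store them in seg["rec_logp"].
    """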
b = tu.batch_len(seg)
with th.no_grad():
logps = []
for inds in th.arange(b).split(mbsize):
mb = tu.tree_slice(seg, inds)
pd, _, _, _ = model(mb["ob"], mb["first"], mb["state_in"])
logp = tu.sum_nonbatch(pd.log_prob(mb["ac"]))
logps.append(logp)
seg["rec_logp"] = tree_cat(logps)
def compute_losses(
model,
model_ewma,
ob,
ac,
first,
logp,
rec_logp,
vtarg,
adv,
state_in,
clip_param,
vfcoef,
entcoef,
kl_penalty,
imp_samp_max,
):
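    """
    Compute the PPO surrogate losses ("pi" and "vf") and diagnostics for one minibatch.

    logp is the behavior-policy log-probability used for importance sampling, while
    rec_logp is the proximal log-probability used for clipping and the KL penalty.
    When model_ewma is given, rec_logp is replaced by the log-probability of ac under
    the EWMA policy. Returns (losses, diags).
    """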
losses = {}
diags = {}
pd, vpred, aux, _state_out = model(ob=ob, first=first, state_in=state_in)
newlogp = tu.sum_nonbatch(pd.log_prob(ac))
if model_ewma is not None:
pd_ewma, _vpred_ewma, _, _state_out_ewma = model_ewma(
ob=ob, first=first, state_in=state_in
)
rec_logp = tu.sum_nonbatch(pd_ewma.log_prob(ac))
# prob ratio for KL / clipping based on a (possibly) recomputed logp
logratio = newlogp - rec_logp
# stale data can give rise to very large importance sampling ratios,
# especially when using the wrong behavior policy,
# so we need to clip them for numerical stability.
# this can introduce bias, but by default we only clip extreme ratios
# to minimize this effect
logp_adj = logp
if imp_samp_max > 0:
logp_adj = th.max(logp, newlogp.detach() - math.log(imp_samp_max))
# because of the large importance sampling ratios again,
# we need to handle the ratios in log space for numerical stability
pg_losses = -adv * th.exp(newlogp - logp_adj)
if clip_param > 0:
clipped_logratio = th.clamp(logratio, math.log(1.0 - clip_param), math.log(1.0 + clip_param))
pg_losses2 = -adv * th.exp(clipped_logratio + rec_logp - logp_adj)
pg_losses = th.max(pg_losses, pg_losses2)
diags["entropy"] = entropy = tu.sum_nonbatch(pd.entropy()).mean()
diags["negent"] = -entropy * entcoef
diags["pg"] = pg_losses.mean()
diags["pi_kl"] = kl_penalty * 0.5 * (logratio ** 2).mean()
losses["pi"] = diags["negent"] + diags["pg"] + diags["pi_kl"]
losses["vf"] = vfcoef * ((vpred - vtarg) ** 2).mean()
with th.no_grad():
if clip_param > 0:
diags["clipfrac"] = th.logical_or(
logratio < math.log(1.0 - clip_param),
logratio > math.log(1.0 + clip_param),
).float().mean()
diags["approxkl"] = 0.5 * (logratio ** 2).mean()
if imp_samp_max > 0:
diags["imp_samp_clipfrac"] = (newlogp - logp > math.log(imp_samp_max)).float().mean()
return losses, diags
class EwmaMoments:
"""
Calculate rolling moments using EWMAs.
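
    Each call to update() folds in a batch of `count` unit-weight samples with the given
    mean and variance, after decaying all existing weights by ewma_decay. w and ww track
    the sum of weights and the sum of squared weights; moments() combines them into the
    reliability-weighted mean and variance (unbiased when ddof=1).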
"""
def __init__(self, ewma_decay):
self.ewma_decay = ewma_decay
self.w = 0.0
self.ww = 0.0 # sum of squared weights
self.wsum = 0.0
self.wsumsq = 0.0
def update(self, mean, var, count, *, ddof=0):
self.w *= self.ewma_decay
self.ww *= self.ewma_decay ** 2
self.wsum *= self.ewma_decay
self.wsumsq *= self.ewma_decay
self.w += count
self.ww += count
self.wsum += mean * count
self.wsumsq += (count - ddof) * var + count * mean ** 2
def moments(self, *, ddof=0):
mean = self.wsum / self.w
# unbiased weighted sample variance:
# https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights
var = (self.wsumsq - self.wsum ** 2 / self.w) / (self.w - ddof * self.ww / self.w)
return mean, var
def learn(
*,
venv: "(VecEnv) vectorized environment",
model: "(ppo.PpoModel)",
model_ewma: "(ppg.EwmaModel) alternate model used for clipping or the KL penalty",
interacts_total: "(float) total timesteps of interaction" = float("inf"),
nstep: "(int) number of serial timesteps" = 256,
γ: "(float) discount" = 0.99,
λ: "(float) GAE parameter" = 0.95,
clip_param: "(float) PPO parameter for clipping prob ratio" = 0.2,
vfcoef: "(float) value function coefficient" = 0.5,
entcoef: "(float) entropy coefficient" = 0.01,
nminibatch: "(int) number of minibatches to break epoch of data into" = 4,
n_epoch_vf: "(int) number of epochs to use when training the value function" = 1,
n_epoch_pi: "(int) number of epochs to use when training the policy" = 1,
lr: "(float) Adam learning rate" = 5e-4,
beta1: "(float) Adam beta1" = 0.9,
beta2: "(float) Adam beta2" = 0.999,
default_loss_weights: "(dict) default_loss_weights" = {},
store_segs: "(bool) whether or not to store segments in a buffer" = True,
verbose: "(bool) print per-epoch loss stats" = True,
log_save_opts: "(dict) passed into LogSaveHelper" = {},
rnorm: "(bool) reward normalization" = True,
kl_penalty: "(int) weight of the KL penalty, which can be used in place of clipping" = 0,
adv_ewma_decay: "(float) EWMA decay for advantage normalization" = 0.0,
grad_weight: "(float) relative weight of this worker's gradients" = 1,
comm: "(MPI.Comm) MPI communicator" = None,
callbacks: "(seq of function(dict)->bool) to run each update" = (),
learn_state: "dict with optional keys {'opts', 'roller', 'lsh', 'reward_normalizer', 'curr_interact_count', 'seg_buf', 'segs_delayed', 'adv_moments'}" = None,
staleness: "(int) number of iterations by which to make data artificially stale, for experimentation" = 0,
staleness_loss: "(str) one of 'decoupled', 'behavior' or 'proximal', only used if staleness > 0" = "decoupled",
imp_samp_max: "(float) value at which to clip importance sampling ratio" = 100.0,
):
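    """
    Run PPO on venv until interacts_total environment steps (summed across MPI workers)
    have been collected, or until a callback returns True.

    Relative to plain PPO, this adds an optional EWMA policy (model_ewma) whose
    log-probabilities are used for clipping and the KL penalty, EWMA-based advantage
    normalization, and options for training on artificially stale data. Returns a dict
    of mutable training state that can be passed back in as learn_state to resume.
    """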
if comm is None:
comm = MPI.COMM_WORLD
learn_state = learn_state or {}
ic_per_step = venv.num * comm.size * nstep
opt_keys = (
["pi", "vf"] if (n_epoch_pi != n_epoch_vf) else ["pi"]
) # use separate optimizers when n_epoch_pi != n_epoch_vf
params = list(model.parameters())
opts = learn_state.get("opts") or {
k: th.optim.Adam(params, lr=lr, betas=(beta1, beta2))
for k in opt_keys
}
tu.sync_params(params)
if rnorm:
reward_normalizer = learn_state.get("reward_normalizer") or RewardNormalizer(venv.num)
else:
reward_normalizer = None
    def get_weight(k):
        return default_loss_weights.get(k, 1.0)
def train_with_losses_and_opt(loss_keys, opt, **arrays):
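        """Compute the requested losses on one minibatch, sync gradients across workers,
        step the optimizer, and update the EWMA policy after each policy update."""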
losses, diags = compute_losses(
model,
model_ewma=model_ewma,
entcoef=entcoef,
kl_penalty=kl_penalty,
clip_param=clip_param,
vfcoef=vfcoef,
imp_samp_max=imp_samp_max,
**arrays,
)
loss = sum([losses[k] * get_weight(k) for k in loss_keys])
opt.zero_grad()
loss.backward()
tu.warn_no_gradient(model, "PPO")
tu.sync_grads(params, grad_weight=grad_weight)
diags = {k: v.detach() for (k, v) in diags.items()}
opt.step()
if "pi" in loss_keys and model_ewma is not None:
model_ewma.update()
diags.update({f"loss_{k}": v.detach() for (k, v) in losses.items()})
return diags
def train_pi(**arrays):
return train_with_losses_and_opt(["pi"], opts["pi"], **arrays)
def train_vf(**arrays):
return train_with_losses_and_opt(["vf"], opts["vf"], **arrays)
def train_pi_and_vf(**arrays):
return train_with_losses_and_opt(["pi", "vf"], opts["pi"], **arrays)
roller = learn_state.get("roller") or Roller(
act_fn=model.act,
venv=venv,
initial_state=model.initial_state(venv.num),
keep_buf=100,
keep_non_rolling=log_save_opts.get("log_new_eps", False),
)
lsh = learn_state.get("lsh") or LogSaveHelper(
ic_per_step=ic_per_step, model=model, comm=comm, **log_save_opts
)
callback_exit = False # Does callback say to exit loop?
curr_interact_count = learn_state.get("curr_interact_count") or 0
curr_iteration = 0
seg_buf = learn_state.get("seg_buf") or []
segs_delayed = learn_state.get("segs_delayed") or Queue(maxsize=staleness + 1)
adv_moments = learn_state.get("adv_moments") or EwmaMoments(adv_ewma_decay)
while curr_interact_count < interacts_total and not callback_exit:
seg = roller.multi_step(nstep)
lsh.gather_roller_stats(roller)
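        # Optionally train on data collected `staleness` iterations ago. staleness_loss
        # controls which log-probabilities play which role: "behavior" clips against the
        # stored behavior logp, "proximal" importance-samples against the recomputed
        # (current-policy) logp, and "decoupled" keeps the two roles separate.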
if staleness > 0:
segs_delayed.put(seg)
if not segs_delayed.full():
continue
seg = segs_delayed.get()
if staleness_loss == "behavior":
seg["rec_logp"] = seg["logp"]
else:
recompute_logp(model=model, seg=seg, mbsize=4)
if staleness_loss == "proximal":
seg["logp"] = seg["rec_logp"]
else:
seg["rec_logp"] = seg["logp"]
if rnorm:
seg["reward"] = reward_normalizer(seg["reward"], seg["first"])
compute_advantage(model, seg, γ, λ, comm=comm, adv_moments=adv_moments)
if store_segs:
seg_buf.append(tree_map(lambda x: x.cpu(), seg))
with logger.profile_kv("optimization"):
# when n_epoch_pi != n_epoch_vf, we perform separate policy and vf epochs with separate optimizers
if n_epoch_pi != n_epoch_vf:
minibatch_optimize(
train_vf,
{k: seg[k] for k in INPUT_KEYS},
nminibatch=nminibatch,
comm=comm,
nepoch=n_epoch_vf,
verbose=verbose,
)
train_fn = train_pi
else:
train_fn = train_pi_and_vf
epoch_stats = minibatch_optimize(
train_fn,
{k: seg[k] for k in INPUT_KEYS},
nminibatch=nminibatch,
comm=comm,
nepoch=n_epoch_pi,
verbose=verbose,
)
for (k, v) in epoch_stats[-1].items():
logger.logkv("Opt/" + k, v)
lsh()
curr_interact_count += ic_per_step
curr_iteration += 1
for callback in callbacks:
callback_exit = callback_exit or bool(callback(locals()))
return dict(
opts=opts,
roller=roller,
lsh=lsh,
reward_normalizer=reward_normalizer,
curr_interact_count=curr_interact_count,
seg_buf=seg_buf,
segs_delayed=segs_delayed,
adv_moments=adv_moments,
)
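# Note: the dict returned by learn() holds the mutable training state (optimizers, roller,
# log/save helper, EWMA moments, ...). A hypothetical resumption sketch, assuming `venv`,
# `model`, and `model_ewma` were constructed elsewhere:
#
#     state = learn(venv=venv, model=model, model_ewma=model_ewma, interacts_total=1e6)
#     learn(venv=venv, model=model, model_ewma=model_ewma, learn_state=state,
#           interacts_total=2e6)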