# ppo_ewma/ppg.py
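"""
Phasic Policy Gradient: alternate a PPO policy phase with an auxiliary
("sleep") phase that fits the value heads while constraining the policy with
a KL-divergence penalty (see aux_train). EwmaModel additionally maintains an
exponentially averaged copy of the policy that learn() hands to ppo.learn.
"""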

from functools import partial
from copy import deepcopy
from . import ppo
from . import logger
import torch as th
import itertools
from . import torch_util as tu
from torch import distributions as td
from .distr_builder import distr_builder
from mpi4py import MPI
from .tree_util import tree_map, tree_reduce
import operator


def sum_nonbatch(logprob_tree):
    """
    sums over nonbatch dimensions and over all leaves of the tree
    use with nested action spaces, which require Product distributions
    """
    return tree_reduce(operator.add, tree_map(tu.sum_nonbatch, logprob_tree))


class PpoModel(th.nn.Module):
    def forward(self, ob, first, state_in) -> "pd, vpred, aux, state_out":
        raise NotImplementedError

    @tu.no_grad
    def act(self, ob, first, state_in):
        pd, vpred, _, state_out = self(
            ob=tree_map(lambda x: x[:, None], ob),
            first=first[:, None],
            state_in=state_in,
        )
        ac = pd.sample()
        logp = sum_nonbatch(pd.log_prob(ac))
        return (
            tree_map(lambda x: x[:, 0], ac),
            state_out,
            dict(vpred=vpred[:, 0], logp=logp[:, 0]),
        )

    @tu.no_grad
    def v(self, ob, first, state_in):
        _pd, vpred, _, _state_out = self(
            ob=tree_map(lambda x: x[:, None], ob),
            first=first[:, None],
            state_in=state_in,
        )
        return vpred[:, 0]


class PhasicModel(PpoModel):
    def forward(self, ob, first, state_in) -> "pd, vpred, aux, state_out":
        raise NotImplementedError

    def compute_aux_loss(self, aux, mb):
        raise NotImplementedError

    def initial_state(self, batchsize):
        raise NotImplementedError

    def aux_keys(self) -> "list of keys needed in mb dict for compute_aux_loss":
        raise NotImplementedError

    def set_aux_phase(self, is_aux_phase: bool):
        "sometimes you want to modify the model, e.g. add a stop gradient"


class PhasicValueModel(PhasicModel):
    def __init__(
        self,
        obtype,
        actype,
        enc_fn,
        arch="dual",  # shared, detach, dual
    ):
        super().__init__()

        detach_value_head = False
        vf_keys = None
        pi_key = "pi"

        if arch == "shared":
            true_vf_key = "pi"
        elif arch == "detach":
            true_vf_key = "pi"
            detach_value_head = True
        elif arch == "dual":
            true_vf_key = "vf"
        else:
            assert False

        vf_keys = vf_keys or [true_vf_key]
        self.pi_enc = enc_fn(obtype)
        self.pi_key = pi_key
        self.true_vf_key = true_vf_key
        self.vf_keys = vf_keys
        self.enc_keys = list(set([pi_key] + vf_keys))
        self.detach_value_head = detach_value_head

        pi_outsize, self.make_distr = distr_builder(actype)

        for k in self.enc_keys:
            self.set_encoder(k, enc_fn(obtype))

        for k in self.vf_keys:
            lastsize = self.get_encoder(k).codetype.size
            self.set_vhead(k, tu.NormedLinear(lastsize, 1, scale=0.1))

        lastsize = self.get_encoder(self.pi_key).codetype.size
        self.pi_head = tu.NormedLinear(lastsize, pi_outsize, scale=0.1)
        self.aux_vf_head = tu.NormedLinear(lastsize, 1, scale=0.1)

    def compute_aux_loss(self, aux, seg):
        vtarg = seg["vtarg"]
        return {
            "vf_aux": 0.5 * ((aux["vpredaux"] - vtarg) ** 2).mean(),
            "vf_true": 0.5 * ((aux["vpredtrue"] - vtarg) ** 2).mean(),
        }

    def reshape_x(self, x):
        b, t = x.shape[:2]
        x = x.reshape(b, t, -1)
        return x

    def get_encoder(self, key):
        return getattr(self, key + "_enc")

    def set_encoder(self, key, enc):
        setattr(self, key + "_enc", enc)

    def get_vhead(self, key):
        return getattr(self, key + "_vhead")

    def set_vhead(self, key, layer):
        setattr(self, key + "_vhead", layer)

    def forward(self, ob, first, state_in):
        state_out = {}
        x_out = {}
        for k in self.enc_keys:
            x_out[k], state_out[k] = self.get_encoder(k)(ob, first, state_in[k])
            x_out[k] = self.reshape_x(x_out[k])
        pi_x = x_out[self.pi_key]
        pivec = self.pi_head(pi_x)
        pd = self.make_distr(pivec)
        aux = {}
        for k in self.vf_keys:
            if self.detach_value_head:
                x_out[k] = x_out[k].detach()
            aux[k] = self.get_vhead(k)(x_out[k])[..., 0]
        vfvec = aux[self.true_vf_key]
        aux.update({"vpredaux": self.aux_vf_head(pi_x)[..., 0], "vpredtrue": vfvec})
        return pd, vfvec, aux, state_out

    def initial_state(self, batchsize):
        return {k: self.get_encoder(k).initial_state(batchsize) for k in self.enc_keys}

    def aux_keys(self):
        return ["vtarg"]
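

# The three `arch` options trade off sharing between policy and value: "shared"
# reads the true value head off the policy trunk, "detach" does the same but
# stops gradients at the trunk, and "dual" gives the value function its own
# encoder. In every case `aux_vf_head` adds an auxiliary value head on the
# policy trunk, trained during the auxiliary phase via the "vpredaux" output.
#
# Construction sketch (illustrative only; `enc_fn`, `obtype`, and `actype` are
# assumed to come from the surrounding codebase, with `enc_fn(obtype)`
# returning an encoder exposing `codetype.size`, `initial_state(batchsize)`,
# and `__call__(ob, first, state_in)`):
#
#     model = PhasicValueModel(obtype, actype, enc_fn, arch="dual")
#     state_in = model.initial_state(batchsize=nenv)
#     ac, state_in, extras = model.act(ob, first, state_in)  # extras: vpred, logp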


class EwmaModel(PpoModel):
    """
    An EWMA-lagged copy of a PpoModel.
    """

    def __init__(self, model, ewma_decay):
        super().__init__()
        self.model = model
        self.ewma_decay = ewma_decay
        self.model_ewma = deepcopy(model)
        self.total_weight = 1

    def forward(self, *args, **kwargs):
        with th.no_grad():
            return self.model_ewma(*args, **kwargs)

    def update(self, decay=None):
        if decay is None:
            decay = self.ewma_decay
        new_total_weight = decay * self.total_weight + 1
        decayed_weight_ratio = decay * self.total_weight / new_total_weight
        for p, p_ewma in zip(self.model.parameters(), self.model_ewma.parameters()):
            p_ewma.data.mul_(decayed_weight_ratio).add_(p.data / new_total_weight)
        self.total_weight = new_total_weight

    def reset(self):
        self.update(decay=0)
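

# EwmaModel.update maintains the debiased exponential moving average
#
#     p_ewma = sum_i(decay**(t - i) * p_i) / sum_i(decay**(t - i)),  i = 0..t,
#
# where p_0 is the deepcopy snapshot taken at construction; `total_weight`
# tracks the denominator, so the average is exact even after few updates.
# reset() uses decay=0, which makes total_weight 1 and collapses p_ewma to an
# exact copy of the current parameters. A minimal self-check of the recurrence
# (illustrative only, not part of the module):
#
#     m = th.nn.Linear(1, 1, bias=False)
#     e = EwmaModel(m, ewma_decay=0.5)  # p_0 = snapshot of m's random init
#     with th.no_grad():
#         m.weight.fill_(1.0)
#     e.update()  # p_ewma == (0.5 * p_0 + 1.0) / 1.5, total_weight == 1.5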


def make_minibatches(segs, mbsize):
    """
    Yield one epoch of minibatches over the dataset described by segs.
    Each minibatch mixes data between different segs.
    """
    if mbsize < 1:
        nchunks = int(round(1 / mbsize))
        assert th.isclose(
            th.tensor(1 / mbsize), th.tensor(float(nchunks))
        ), "mbsize must be an integer or 1 divided by an integer"
        mbsize = 1
    else:
        nchunks = 1
    nenv = tu.batch_len(segs[0])
    nseg = len(segs)
    envs_segs = th.tensor(list(itertools.product(range(nenv), range(nseg))))
    for perminds in th.randperm(len(envs_segs)).split(mbsize):
        esinds = envs_segs[perminds]
        mb = tu.tree_stack(
            [tu.tree_slice(segs[segind], envind) for (envind, segind) in esinds]
        )
        for chunknum in range(nchunks):
            # raises an IndexError if aux_mbsize < 1 is incompatible with nstep
            yield tree_map(lambda x: th.chunk(x, nchunks, dim=1)[chunknum], mb)


def aux_train(*, model, segs, opt, mbsize, name2coef):
    """
    Train on auxiliary loss + policy KL + vf distance
    """
    needed_keys = {"ob", "first", "state_in", "oldpd"}.union(model.aux_keys())
    segs = [{k: seg[k] for k in needed_keys} for seg in segs]
    for mb in make_minibatches(segs, mbsize):
        mb = tree_map(lambda x: x.to(tu.dev()), mb)
        pd, _, aux, _state_out = model(mb["ob"], mb["first"], mb["state_in"])
        name2loss = {}
        name2loss["pol_distance"] = td.kl_divergence(mb["oldpd"], pd).mean()
        name2loss.update(model.compute_aux_loss(aux, mb))
        assert set(name2coef.keys()).issubset(name2loss.keys())
        loss = 0
        for name in name2loss.keys():
            unscaled_loss = name2loss[name]
            scaled_loss = unscaled_loss * name2coef.get(name, 1)
            logger.logkv_mean("unscaled/" + name, unscaled_loss)
            logger.logkv_mean("scaled/" + name, scaled_loss)
            loss += scaled_loss
        opt.zero_grad()
        loss.backward()
        tu.sync_grads(model.parameters())
        opt.step()


def compute_presleep_outputs(
    *, model, segs, mbsize, pdkey="oldpd", vpredkey="oldvpred"
):
    def forward(ob, first, state_in):
        pd, vpred, _aux, _state_out = model.forward(ob.to(tu.dev()), first, state_in)
        return pd, vpred

    for seg in segs:
        seg[pdkey], seg[vpredkey] = tu.minibatched_call(
            forward, mbsize, ob=seg["ob"], first=seg["first"], state_in=seg["state_in"]
        )


def learn(
    *,
    model,
    venv,
    ppo_hps,
    aux_lr,
    aux_mbsize,
    aux_beta1=0.9,
    aux_beta2=0.999,
    n_aux_epochs=6,
    n_pi=32,
    kl_ewma_decay=None,
    interacts_total=float("inf"),
    name2coef=None,
    comm=None,
):
    """
    Run PPO for n_pi iterations, then minimize aux loss + policy KL + value
    distance for n_aux_epochs passes over the stored data, and repeat.
    """
    if comm is None:
        comm = MPI.COMM_WORLD
    ppo_state = None
    aux_state = th.optim.Adam(
        model.parameters(), lr=aux_lr, betas=(aux_beta1, aux_beta2)
    )
    name2coef = name2coef or {}
    if kl_ewma_decay is not None:
        model_ewma = EwmaModel(model, kl_ewma_decay)
    else:
        model_ewma = None
    nstep = ppo_hps.get("nstep", 256)

    while True:
        store_segs = n_pi != 0 and n_aux_epochs != 0

        # Policy phase
        ppo_state = ppo.learn(
            venv=venv,
            model=model,
            model_ewma=model_ewma,
            learn_state=ppo_state,
            callbacks=[
                lambda _l: n_pi > 0 and _l["curr_iteration"] >= n_pi,
            ],
            interacts_total=interacts_total,
            store_segs=store_segs,
            comm=comm,
            **ppo_hps,
        )
        if ppo_state["curr_interact_count"] >= interacts_total:
            break

        if n_aux_epochs > 0:
            segs = ppo_state["seg_buf"]
            compute_presleep_outputs(model=model, segs=segs, mbsize=max(1, aux_mbsize))

            # Auxiliary phase
            for i in range(n_aux_epochs):
                logger.log(f"Aux epoch {i}")
                aux_train(
                    model=model,
                    segs=segs,
                    opt=aux_state,
                    mbsize=aux_mbsize,
                    name2coef=name2coef,
                )
                logger.dumpkvs()
            segs.clear()

            if model_ewma is not None:
                model_ewma.reset()
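

# Usage sketch (illustrative values; `venv` and `model` are assumed to come
# from the rest of the codebase, and the defaults shown match this file):
#
#     learn(
#         model=model,
#         venv=venv,
#         ppo_hps=dict(nstep=256),
#         aux_lr=5e-4,     # illustrative
#         aux_mbsize=4,    # illustrative; values < 1 chunk along time instead
#         n_pi=32,
#         n_aux_epochs=6,
#     )
#
# n_aux_epochs=0 skips the auxiliary phase entirely (plain PPO), and a
# non-None kl_ewma_decay wraps the policy in an EwmaModel that is passed to
# ppo.learn and reset after each auxiliary phase.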