in ppo_ewma/ppg.py [0:0]
# Imports needed by this snippet; the module paths assume the package's
# own logger, torch_util, and tree_util modules.
import torch.distributions as td

from . import logger
from . import torch_util as tu
from .tree_util import tree_map

def aux_train(*, model, segs, opt, mbsize, name2coef):
    """
    Train on auxiliary loss + policy KL + vf distance
    """
    # Keep only the keys the auxiliary phase needs, to save replay memory.
    needed_keys = {"ob", "first", "state_in", "oldpd"}.union(model.aux_keys())
    segs = [{k: seg[k] for k in needed_keys} for seg in segs]
    for mb in make_minibatches(segs, mbsize):
        # Move every tensor in the minibatch to the training device.
        mb = tree_map(lambda x: x.to(tu.dev()), mb)
        pd, _, aux, _state_out = model(mb["ob"], mb["first"], mb["state_in"])
        name2loss = {}
        # KL(old policy || current policy): penalizes drift away from the
        # policy that collected the data.
        name2loss["pol_distance"] = td.kl_divergence(mb["oldpd"], pd).mean()
        # Model-specific auxiliary losses (e.g. value-function distance).
        name2loss.update(model.compute_aux_loss(aux, mb))
        # Every provided coefficient must correspond to a computed loss.
        assert set(name2coef.keys()).issubset(name2loss.keys())
        loss = 0
        for name in name2loss.keys():
            unscaled_loss = name2loss[name]
            # Weight each term by its coefficient, defaulting to 1.
            scaled_loss = unscaled_loss * name2coef.get(name, 1)
            logger.logkv_mean("unscaled/" + name, unscaled_loss)
            logger.logkv_mean("scaled/" + name, scaled_loss)
            loss += scaled_loss
        opt.zero_grad()
        loss.backward()
        # Average gradients across data-parallel workers before stepping.
        tu.sync_grads(model.parameters())
        opt.step()
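
make_minibatches is defined elsewhere in the file. As a minimal sketch of the contract aux_train relies on, assuming tu.batch_len, tu.tree_slice, and tu.tree_stack are tree utilities over nested dicts of tensors (and that they understand the distribution-valued "oldpd" entry), such a helper could shuffle (env, segment) index pairs and stack the corresponding slices into minibatches:

import itertools

import torch as th

def make_minibatches(segs, mbsize):
    # Sketch only: yield one epoch of minibatches over the replay buffer.
    nenv = tu.batch_len(segs[0])  # envs per segment (assumed helper)
    nseg = len(segs)
    env_seg_pairs = list(itertools.product(range(nenv), range(nseg)))
    for perm in th.randperm(len(env_seg_pairs)).split(mbsize):
        # Stack one sample per selected (env, segment) pair.
        yield tu.tree_stack(
            [tu.tree_slice(segs[seg], env)
             for env, seg in (env_seg_pairs[i] for i in perm.tolist())]
        )

A call site would then supply the loss weights; the names below (replay_segs, aux_opt) and the coefficient value are illustrative, not taken from this file:

aux_train(
    model=model,
    segs=replay_segs,  # list of rollout segments (hypothetical variable)
    opt=aux_opt,       # optimizer over model.parameters() (hypothetical)
    mbsize=4,
    name2coef={"pol_distance": 1.0},  # weight on the policy-KL term
)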