def aux_train()

in ppo_ewma/ppg.py [0:0]


def aux_train(*, model, segs, opt, mbsize, name2coef):
    """
    Train on auxiliary loss + policy KL + vf distance.

    Makes one pass over ``segs`` in minibatches, taking one optimizer step per
    minibatch. The loss for a minibatch is the sum of every named loss below,
    each multiplied by its coefficient from ``name2coef`` (default 1):

    - ``"pol_distance"``: mean ``KL(oldpd || pd)`` between the stored
      distribution ``mb["oldpd"]`` and the current model output ``pd``.
    - every loss returned by ``model.compute_aux_loss(aux, mb)`` (presumably
      including the vf distance mentioned above; verify against the model).

    Both the unscaled and scaled value of every loss are logged via
    ``logger.logkv_mean`` under ``"unscaled/<name>"`` and ``"scaled/<name>"``.

    Args:
        model: network that provides ``aux_keys()`` (extra segment keys it
            needs), is callable as ``model(ob, first, state_in)`` returning a
            4-tuple ``(pd, _, aux, state_out)``, and provides
            ``compute_aux_loss(aux, mb)`` returning a dict of named losses.
        segs: iterable of segment dicts. Each must contain ``"ob"``,
            ``"first"``, ``"state_in"``, ``"oldpd"`` and every key in
            ``model.aux_keys()``. Other keys are discarded.
        opt: optimizer exposing ``zero_grad()`` and ``step()``.
        mbsize: forwarded to ``make_minibatches`` (presumably the minibatch
            size; confirm against its definition).
        name2coef: mapping from loss name to coefficient. Losses without an
            entry use a coefficient of 1.

    Raises:
        ValueError: if ``name2coef`` has a coefficient for a loss name that
            is not produced for a minibatch.
    """
    needed_keys = {"ob", "first", "state_in", "oldpd"}.union(model.aux_keys())
    # Keep only the keys used below so nothing else is minibatched or moved to device.
    segs = [{k: seg[k] for k in needed_keys} for seg in segs]
    for mb in make_minibatches(segs, mbsize):
        mb = tree_map(lambda x: x.to(tu.dev()), mb)
        pd, _, aux, _state_out = model(mb["ob"], mb["first"], mb["state_in"])
        name2loss = {}
        # KL(oldpd || pd): penalizes the current policy drifting from the stored one.
        name2loss["pol_distance"] = td.kl_divergence(mb["oldpd"], pd).mean()
        name2loss.update(model.compute_aux_loss(aux, mb))
        # Explicit check rather than `assert` so a misspelled coefficient name
        # is still caught when Python runs with -O.
        unknown = set(name2coef).difference(name2loss)
        if unknown:
            raise ValueError(
                f"name2coef has coefficients for unknown losses "
                f"{sorted(unknown, key=str)}; available losses: "
                f"{sorted(name2loss, key=str)}"
            )
        loss = 0
        for name, unscaled_loss in name2loss.items():
            scaled_loss = unscaled_loss * name2coef.get(name, 1)
            logger.logkv_mean("unscaled/" + name, unscaled_loss)
            logger.logkv_mean("scaled/" + name, scaled_loss)
            loss += scaled_loss
        opt.zero_grad()
        loss.backward()
        # NOTE(review): presumably averages/all-reduces grads across workers
        # before stepping; confirm against tu.sync_grads.
        tu.sync_grads(model.parameters())
        opt.step()