in ppo_ewma/ppo.py [0:0]
def train_with_losses_and_opt(loss_keys, opt, **arrays):
    """Take one optimizer step on a weighted sum of the selected losses.

    ``model``, ``model_ewma``, ``params``, ``grad_weight``, ``get_weight``
    and the loss hyperparameters are not arguments; they are resolved from
    the surrounding scope, which is not visible in this excerpt.

    Args:
        loss_keys: Keys into the loss dict returned by ``compute_losses``
            whose entries are weighted and summed into the objective.
        opt: Optimizer; ``zero_grad()`` and ``step()`` are called on it.
        **arrays: Forwarded unchanged to ``compute_losses``.

    Returns:
        A dict with every diagnostic returned by ``compute_losses`` plus one
        ``loss_<key>`` entry per loss (all of them, not only those listed
        in ``loss_keys``); every value has been passed through ``.detach()``.
    """
    fixed_kwargs = dict(
        model_ewma=model_ewma,
        entcoef=entcoef,
        kl_penalty=kl_penalty,
        clip_param=clip_param,
        vfcoef=vfcoef,
        imp_samp_max=imp_samp_max,
    )
    all_losses, raw_diags = compute_losses(model, **fixed_kwargs, **arrays)

    # Only the requested losses feed the objective, each scaled by its weight.
    weighted_terms = (all_losses[key] * get_weight(key) for key in loss_keys)
    objective = sum(weighted_terms)

    opt.zero_grad()
    objective.backward()
    # NOTE(review): presumably reports parameters that received no gradient
    # -- confirm against tu.warn_no_gradient.
    tu.warn_no_gradient(model, "PPO")
    # NOTE(review): presumably combines gradients across workers -- confirm
    # against tu.sync_grads.
    tu.sync_grads(params, grad_weight=grad_weight)

    # Diagnostics are detached before the parameter update is applied.
    report = {name: value.detach() for name, value in raw_diags.items()}
    opt.step()

    # The EWMA model is refreshed only on steps that optimized the "pi" loss.
    policy_in_objective = "pi" in loss_keys
    if policy_in_objective and model_ewma is not None:
        model_ewma.update()

    for key, value in all_losses.items():
        report[f"loss_{key}"] = value.detach()
    return report