in scripts/models.py [0:0]
def fit(self, envs, num_iterations, callback=False):
    """Run ``num_iterations`` optimisation steps with a gradient-deviation penalty.

    Every step evaluates ``self.loss(self.network(x), y)`` for each
    ``(x, y)`` pair in ``envs["train"]["envs"]`` (presumably one pair per
    training environment), differentiates each of those losses and their
    average with respect to ``self.parameters()``, and then minimises::

        avg_loss + self.hparams['penalty'] * penalty

    where ``penalty`` is the summed squared difference between every
    per-pair gradient and the gradient of the averaged loss.

    Args:
        envs: Mapping indexed as ``envs["train"]["envs"]`` and iterated
            once per step as ``(x, y)`` pairs; also forwarded unchanged to
            ``utils.compute_errors`` when ``callback`` is truthy.
        num_iterations: Number of optimizer steps to perform.
        callback: When truthy, ``utils.compute_errors(self, envs)`` is
            invoked after every step (its result is not used here).

    Returns:
        None. The model is updated through ``self.optimizer``.

    Raises:
        ZeroDivisionError: If ``envs["train"]["envs"]`` yields no pairs
            while ``num_iterations`` is positive.
    """
    for _ in range(num_iterations):
        # Forward pass: one scalar loss per (x, y) pair.
        env_losses = []
        for inputs, targets in envs["train"]["envs"]:
            env_losses.append(self.loss(self.network(inputs), targets))

        # create_graph=True keeps each gradient differentiable, which is
        # required because the penalty below is itself back-propagated.
        env_grads = []
        for env_loss in env_losses:
            env_grads.append(
                grad(env_loss, self.parameters(), create_graph=True)
            )

        # Averaged loss and its gradient.
        mean_loss = sum(env_losses) / len(env_losses)
        mean_grad = grad(mean_loss, self.parameters(), create_graph=True)

        # Trace penalty: squared distance of every per-pair gradient from
        # the gradient of the averaged loss, summed over all parameters.
        penalty = 0
        for env_grad in env_grads:
            for grad_piece, mean_piece in zip(env_grad, mean_grad):
                deviation = grad_piece - mean_piece
                penalty += deviation.pow(2).sum()

        self.optimizer.zero_grad()
        objective = mean_loss + self.hparams['penalty'] * penalty
        objective.backward()
        self.optimizer.step()

        if not callback:
            continue
        # compute errors
        utils.compute_errors(self, envs)