def fit()

in scripts/models.py [0:0]


    def fit(self, envs, num_iterations, callback=False):
        """Fit the model with a gradient-alignment (trace) penalty.

        Every iteration takes one optimiser step on

            mean_e L_e
            + hparams['penalty'] * sum_e ||grad L_e - mean_e grad L_e||^2

        where ``L_e`` is ``self.loss(self.network(x_e), y_e)`` for training
        environment ``e``. The penalty is the trace of the (unnormalised)
        covariance of the per-environment parameter gradients. It pushes each
        environment's gradient towards the average gradient.

        Args:
            envs: mapping whose ``envs["train"]["envs"]`` iterates over
                ``(x, y)`` pairs, one per training environment. It is also
                forwarded unchanged to ``utils.compute_errors`` when
                ``callback`` is set.
            num_iterations: number of optimiser steps to take.
            callback: if truthy, call ``utils.compute_errors(self, envs)``
                after every step (its return value is discarded).

        Raises:
            ValueError: if ``envs["train"]["envs"]`` yields no environments.
        """
        # Materialise once so the same list can be reused for every grad()
        # call. The optimiser updates these tensors in place, so the list
        # stays valid across iterations.
        params = list(self.parameters())

        for _ in range(num_iterations):
            losses = [self.loss(self.network(x), y)
                      for x, y in envs["train"]["envs"]]
            if not losses:
                raise ValueError(
                    "fit() needs at least one training environment in "
                    "envs['train']['envs']")
            n_envs = len(losses)

            # create_graph=True keeps each per-environment gradient
            # differentiable, so the penalty built from them can itself be
            # back-propagated. (Assumes `grad` is torch.autograd.grad.)
            gradients = [
                grad(loss, params, create_graph=True)
                for loss in losses
            ]

            # Average loss and gradients. Differentiation is linear, so the
            # gradient of the mean loss equals the mean of the per-environment
            # gradients. Averaging them avoids a second autograd pass through
            # every environment's graph.
            avg_loss = sum(losses) / n_envs
            avg_gradient = [sum(per_param) / n_envs
                            for per_param in zip(*gradients)]

            # Trace penalty: squared deviation of each environment's gradient
            # from the average gradient, summed over environments and
            # parameter tensors.
            penalty_value = 0
            for gradient in gradients:
                for gradient_i, avg_grad_i in zip(gradient, avg_gradient):
                    penalty_value += (gradient_i - avg_grad_i).pow(2).sum()

            self.optimizer.zero_grad()
            (avg_loss + self.hparams['penalty'] * penalty_value).backward()
            self.optimizer.step()

            if callback:
                # compute errors
                utils.compute_errors(self, envs)