def fit()

in scripts/models.py [0:0]


    def fit(self, envs, num_iterations, callback=False):
        """Train ``self.network`` with the loss + gradient-penalty objective.

        Every iteration evaluates ``self.loss`` on each training environment,
        takes the gradient of each per-environment loss with respect to
        ``self.net_dummies``, and minimizes

            (1 - irm_lambda) * mean_e(loss_e)
            + irm_lambda * sum_e sum_d ||d loss_e / d dummy_d||^2

        with a single ``self.optimizer`` step.

        Args:
            envs: mapping whose ``envs["train"]["envs"]`` is an iterable of
                ``(x, y)`` pairs (presumably one pair per training
                environment).
            num_iterations: number of optimizer steps to take.
            callback: if truthy, ``utils.compute_errors(self, envs)`` is
                called after every optimizer step.

        Raises:
            NotImplementedError: if ``self.version`` is not 1; only the
                squared-gradient penalty (version 1) is implemented.
            ValueError: if ``envs["train"]["envs"]`` yields no environments.
        """
        # Validate once, before any forward/backward work, instead of deep
        # inside the penalty loop after the first pass has been computed.
        if self.version != 1:
            raise NotImplementedError(
                f"IRM version {self.version} is not implemented")

        for _ in range(num_iterations):
            losses_env = []
            penalty = 0
            for x, y in envs["train"]["envs"]:
                loss_env = self.loss(self.network(x), y)
                losses_env.append(loss_env)
                # create_graph=True keeps these gradients differentiable so
                # the penalty built from them is itself backpropagated below.
                # NOTE(review): grad() requires every loss to depend on every
                # tensor in self.net_dummies (presumably dummy parameters
                # inside self.network) — confirm against the constructor.
                gradients_this_env = grad(
                    loss_env, self.net_dummies, create_graph=True)
                for g_env in gradients_this_env:
                    penalty += g_env.pow(2).sum()

            if not losses_env:
                raise ValueError(
                    'envs["train"]["envs"] contains no environments')

            # Average loss across envs (the penalty is summed, not averaged).
            losses_avg = sum(losses_env) / len(losses_env)

            irm_lambda = self.hparams["irm_lambda"]
            obj = (1 - irm_lambda) * losses_avg
            obj += irm_lambda * penalty

            self.optimizer.zero_grad()
            obj.backward()
            self.optimizer.step()

            if callback:
                # compute errors
                utils.compute_errors(self, envs)