in scripts/models.py [0:0]
def fit(self, envs, num_iterations, callback=False):
    """Train the network on the training environments with a gradient penalty.

    Each iteration computes, for every training environment, the loss
    ``self.loss(self.network(x), y)`` and the gradient of that loss with
    respect to ``self.net_dummies``. It then takes one optimizer step on::

        (1 - irm_lambda) * mean_env_loss + irm_lambda * penalty

    where ``penalty`` is the sum, over environments and over the tensors in
    ``self.net_dummies``, of the squared gradient entries. This is the
    version-1 penalty, which is the only one implemented.

    Args:
        envs: Mapping whose ``envs["train"]["envs"]`` is an iterable of
            ``(x, y)`` pairs, one pair per training environment. The whole
            ``envs`` mapping is forwarded to ``utils.compute_errors`` when
            ``callback`` is true.
        num_iterations: Number of optimizer steps to take.
        callback: If true, call ``utils.compute_errors(self, envs)`` after
            every optimizer step.

    Raises:
        NotImplementedError: If ``self.version`` is not 1.
        ValueError: If ``envs["train"]["envs"]`` is empty.
    """
    # Fail fast, before any forward/backward work, on configurations the loop
    # below cannot handle.
    if self.version != 1:
        raise NotImplementedError(
            f"penalty version {self.version} is not implemented; "
            "only version 1 is supported")
    # Materialize once so every iteration sees the same environments, even if
    # the caller passed a one-shot iterable.
    train_envs = list(envs["train"]["envs"])
    if not train_envs:
        raise ValueError(
            "envs['train']['envs'] must contain at least one environment")

    for _ in range(num_iterations):
        losses_env = []
        penalty = 0
        for x, y in train_envs:
            loss_env = self.loss(self.network(x), y)
            # create_graph=True keeps the penalty differentiable, so that
            # obj.backward() also propagates through the gradient norms.
            # NOTE(review): `grad` is presumably torch.autograd.grad, which
            # raises if some tensor in self.net_dummies does not influence
            # loss_env. That assumes self.network uses every dummy. Confirm.
            grads_env = grad(loss_env, self.net_dummies, create_graph=True)
            for g in grads_env:
                penalty = penalty + g.pow(2).sum()
            losses_env.append(loss_env)

        # Average loss across envs. The version-1 penalty only needs the
        # per-environment gradients, so no gradient of this average is taken.
        losses_avg = sum(losses_env) / len(losses_env)

        irm_lambda = self.hparams["irm_lambda"]
        obj = (1 - irm_lambda) * losses_avg + irm_lambda * penalty

        self.optimizer.zero_grad()
        obj.backward()
        self.optimizer.step()

        if callback:
            # compute errors
            utils.compute_errors(self, envs)