Fishr.update in domainbed/algorithms.py
def update(self, minibatches, unlabeled=False):
    assert len(minibatches) == self.num_domains

    # Pool the per-domain minibatches, remembering each domain's size so the
    # Fishr penalty can split the per-sample statistics back per domain.
    all_x = torch.cat([x for x, y in minibatches])
    all_y = torch.cat([y for x, y in minibatches])
    len_minibatches = [x.shape[0] for x, y in minibatches]

    all_z = self.featurizer(all_x)
    all_logits = self.classifier(all_z)

    penalty = self.compute_fishr_penalty(all_logits, all_y, len_minibatches)
    all_nll = F.cross_entropy(all_logits, all_y)

    # The penalty only kicks in after the annealing period.
    penalty_weight = 0
    if self.update_count >= self.hparams["penalty_anneal_iters"]:
        penalty_weight = self.hparams["lambda"]
        if self.update_count == self.hparams["penalty_anneal_iters"] != 0:
            # Reset Adam as in IRM or V-REx, because it may not like the sharp jump
            # in gradient magnitudes that happens at this step.
            self._init_optimizer()
    self.update_count += 1

    objective = all_nll + penalty_weight * penalty

    self.optimizer.zero_grad()
    objective.backward()
    self.optimizer.step()

    return {'loss': objective.item(), 'nll': all_nll.item(), 'penalty': penalty.item()}
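
compute_fishr_penalty and _init_optimizer are defined elsewhere in the class. As a rough illustration of the penalty term, here is a minimal, self-contained sketch of Fishr's gradient-variance matching under the simplifying assumption of a linear classifier head, so per-sample gradients have a closed form; the actual DomainBed implementation instead extracts per-sample gradients with BackPACK and smooths the per-domain variances with an exponential moving average, which this sketch omits. The function name and signature below are hypothetical, not the repository's API.

import torch
import torch.nn.functional as F

def fishr_penalty_sketch(all_z, all_logits, all_y, len_minibatches):
    # Per-sample gradient of the cross-entropy loss w.r.t. a linear classifier's
    # weight matrix: g_i = (softmax(logits_i) - onehot(y_i)) outer z_i, flattened.
    probs = F.softmax(all_logits, dim=1)
    onehot = F.one_hot(all_y, num_classes=all_logits.shape[1]).float()
    grads = torch.einsum('nc,nd->ncd', probs - onehot, all_z).flatten(1)  # (N, C*D)

    # Per-domain variance of the gradients, one domain per chunk of len_minibatches.
    variances, idx = [], 0
    for n in len_minibatches:
        variances.append(grads[idx:idx + n].var(dim=0, unbiased=False))
        idx += n
    variances = torch.stack(variances)  # (num_domains, C*D)

    # Penalize how far each domain's gradient variance is from the mean variance.
    mean_var = variances.mean(dim=0, keepdim=True)
    return ((variances - mean_var) ** 2).mean()

The optimizer reset guarded by penalty_anneal_iters presumably just rebuilds Adam over the featurizer and classifier parameters, discarding its running moment estimates before the jump in gradient scale. A plausible sketch, assuming "lr" and "weight_decay" hyperparameters:

def _init_optimizer(self):
    # Hypothetical sketch: re-create Adam from scratch so no stale moments survive.
    self.optimizer = torch.optim.Adam(
        list(self.featurizer.parameters()) + list(self.classifier.parameters()),
        lr=self.hparams["lr"],
        weight_decay=self.hparams["weight_decay"],
    )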