in domainbed/algorithms.py [0:0]
def update(self, minibatches, unlabeled=None):
    """Run one optimization step of the IRM objective on a set of minibatches.

    All minibatches are concatenated and passed through ``self.network`` in a
    single forward pass. The resulting logits are split back into
    per-minibatch chunks. For each chunk, cross-entropy (``nll``) and
    ``self._irm_penalty`` are computed, and each is then averaged over the
    minibatches. The total loss is ``nll + penalty_weight * penalty``.
    ``penalty_weight`` is 1.0 while ``self.update_count`` is below
    ``hparams['irm_penalty_anneal_iters']``. From that count on it is
    ``hparams['irm_lambda']``.

    Args:
        minibatches: Non-empty sequence of ``(x, y)`` pairs. Presumably this is
            one pair per training environment (TODO confirm against caller).
            The ``x`` tensors must be concatenable along dim 0.
        unlabeled: Not used by this method.

    Returns:
        dict with Python-float entries ``'loss'``, ``'nll'`` and ``'penalty'``.

    Side effects:
        Zeroes gradients, backpropagates ``loss`` and steps
        ``self.optimizer``, then increments ``self.update_count``. On the
        step where ``update_count == irm_penalty_anneal_iters``, it first
        replaces ``self.optimizer`` with a fresh ``torch.optim.Adam``.
    """
    anneal_iters = self.hparams['irm_penalty_anneal_iters']
    penalty_weight = (self.hparams['irm_lambda']
                      if self.update_count >= anneal_iters else 1.0)

    # Single forward pass over every minibatch. The logits are then split back
    # into per-minibatch chunks; split returns views, so gradients still flow
    # back through the shared forward pass.
    all_x = torch.cat([x for x, _ in minibatches])
    all_logits = self.network(all_x)
    chunk_sizes = [x.shape[0] for x, _ in minibatches]
    per_batch_logits = all_logits.split(chunk_sizes)

    nll = 0.
    penalty = 0.
    for logits, (_, y) in zip(per_batch_logits, minibatches):
        nll += F.cross_entropy(logits, y)
        penalty += self._irm_penalty(logits, y)
    nll /= len(minibatches)
    penalty /= len(minibatches)
    loss = nll + penalty_weight * penalty

    if self.update_count == anneal_iters:
        # penalty_weight switches from 1.0 to irm_lambda at this step.
        # Reset Adam, because it doesn't like the sharp jump in gradient
        # magnitudes that happens at this step.
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay'])

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    self.update_count += 1
    return {'loss': loss.item(), 'nll': nll.item(),
            'penalty': penalty.item()}