in domainbed/algorithms.py [0:0]
def update(self, minibatches, unlabeled=None):
    # Use the full V-REx penalty weight only after the annealing period;
    # before that, the penalty is down-weighted to 1.0.
    if self.update_count >= self.hparams["vrex_penalty_anneal_iters"]:
        penalty_weight = self.hparams["vrex_lambda"]
    else:
        penalty_weight = 1.0

    nll = 0.

    # Single forward pass over the concatenated examples of all domains,
    # then slice the logits back out per domain.
    all_x = torch.cat([x for x, y in minibatches])
    all_logits = self.network(all_x)
    all_logits_idx = 0
    losses = torch.zeros(len(minibatches))
    for i, (x, y) in enumerate(minibatches):
        logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]]
        all_logits_idx += x.shape[0]
        nll = F.cross_entropy(logits, y)
        losses[i] = nll

    # V-REx objective: mean risk across domains plus the variance of the
    # per-domain risks, scaled by the penalty weight.
    mean = losses.mean()
    penalty = ((losses - mean) ** 2).mean()
    loss = mean + penalty_weight * penalty

    if self.update_count == self.hparams['vrex_penalty_anneal_iters']:
        # Reset Adam (like IRM), because it doesn't like the sharp jump in
        # gradient magnitudes that happens at this step.
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay'])

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    self.update_count += 1
    # Note: 'nll' here is the cross-entropy of the last domain in the loop.
    return {'loss': loss.item(), 'nll': nll.item(),
            'penalty': penalty.item()}
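
As a standalone illustration of the penalty term above (not part of algorithms.py), the sketch below computes the same mean-plus-variance objective for a made-up set of per-domain losses; the loss values and the penalty weight are arbitrary placeholders, not values from DomainBed.

import torch

# Hypothetical per-domain NLLs, standing in for `losses` above.
losses = torch.tensor([0.90, 1.10, 0.60])
penalty_weight = 10.0  # arbitrary stand-in for hparams["vrex_lambda"]

mean = losses.mean()                      # average risk across domains
penalty = ((losses - mean) ** 2).mean()   # variance of per-domain risks
loss = mean + penalty_weight * penalty

print(f"mean={mean.item():.4f}, penalty={penalty.item():.4f}, loss={loss.item():.4f}")
# mean=0.8667, penalty=0.0422, loss=1.2889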