in domainbed/algorithms.py [0:0]
def update(self, minibatches, unlabeled=None):
    """Run one adversarial step: either a discriminator or a generator step.

    Features of all minibatches are passed to a domain discriminator whose
    target for each sample is the index of the minibatch it came from. The
    discriminator loss (plus a gradient penalty on the discriminator input)
    is either minimised with ``self.disc_opt`` or, on generator steps,
    subtracted (scaled by ``hparams['lambda']``) from the classification
    loss, which is then minimised with ``self.gen_opt``.

    Args:
        minibatches: sequence of ``(x, y)`` tensor pairs. The position of a
            pair in the sequence is used as the domain label of its samples.
        unlabeled: accepted for interface compatibility; not used here.

    Returns:
        ``{'disc_loss': float}`` after a discriminator step, or
        ``{'gen_loss': float}`` after a generator step.
    """
    self.update_count += 1

    all_x = torch.cat([x for x, _ in minibatches])
    all_y = torch.cat([y for _, y in minibatches])
    all_z = self.featurizer(all_x)

    if self.conditional:
        # Conditional variant: shift the features by a per-class embedding.
        disc_input = all_z + self.class_embeddings(all_y)
    else:
        disc_input = all_z
    disc_out = self.discriminator(disc_input)

    # Domain label = index of the source minibatch. Build the labels on the
    # very device the inputs live on (not just "cuda"/"cpu"), so a
    # non-default CUDA device works too.
    device = all_x.device
    disc_labels = torch.cat([
        torch.full((x.shape[0],), i, dtype=torch.int64, device=device)
        for i, (x, _) in enumerate(minibatches)
    ])

    if self.class_balance:
        # Weight each sample by 1 / (occurrences of its class in the batch
        # * number of one-hot columns), then sum the weighted losses.
        # NOTE(review): one_hot infers the column count from the largest
        # label present in this batch, which may be smaller than the real
        # number of classes -- confirm whether that is intended.
        y_counts = F.one_hot(all_y).sum(dim=0)
        weights = 1. / (y_counts[all_y] * y_counts.shape[0]).float()
        disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none')
        disc_loss = (weights * disc_loss).sum()
    else:
        disc_loss = F.cross_entropy(disc_out, disc_labels)

    # Gradient penalty: squared norm of d p(true domain | input) / d input,
    # averaged over samples. ``gather`` picks exactly one probability per
    # row. Indexing the columns with the whole label vector would instead
    # build an N x N matrix whose sum is constant when all minibatches have
    # the same size, which makes the penalty vanish.
    disc_softmax = F.softmax(disc_out, dim=1)
    true_domain_prob = disc_softmax.gather(1, disc_labels.unsqueeze(1))
    input_grad = autograd.grad(true_domain_prob.sum(),
                               [disc_input], create_graph=True)[0]
    grad_penalty = (input_grad ** 2).sum(dim=1).mean(dim=0)
    disc_loss = disc_loss + self.hparams['grad_penalty'] * grad_penalty

    # Out of every (1 + d_steps_per_g) consecutive updates, d_steps_per_g
    # are discriminator steps and the remaining one is a generator step.
    d_steps_per_g = self.hparams['d_steps_per_g_step']
    if self.update_count.item() % (1 + d_steps_per_g) < d_steps_per_g:
        self.disc_opt.zero_grad()
        disc_loss.backward()
        self.disc_opt.step()
        return {'disc_loss': disc_loss.item()}

    # Generator step: fit the labels while maximising the discriminator
    # loss (hence the negated term).
    all_preds = self.classifier(all_z)
    classifier_loss = F.cross_entropy(all_preds, all_y)
    gen_loss = (classifier_loss +
                (self.hparams['lambda'] * -disc_loss))
    # Both optimizers' gradients are cleared, but only gen_opt is stepped.
    self.disc_opt.zero_grad()
    self.gen_opt.zero_grad()
    gen_loss.backward()
    self.gen_opt.step()
    return {'gen_loss': gen_loss.item()}