in domainbed/algorithms.py
def update(self, minibatches, unlabeled=None):
    objective = 0
    penalty = 0
    nmb = len(minibatches)

    # One minibatch per training domain: encode each domain's inputs,
    # then classify the resulting features.
    features = [self.featurizer(xi) for xi, _ in minibatches]
    classifs = [self.classifier(fi) for fi in features]
    targets = [yi for _, yi in minibatches]

    # Sum the classification loss over domains and the MMD penalty
    # over all unordered pairs of domains.
    for i in range(nmb):
        objective += F.cross_entropy(classifs[i], targets[i])
        for j in range(i + 1, nmb):
            penalty += self.mmd(features[i], features[j])

    objective /= nmb
    if nmb > 1:
        # Average the penalty over the nmb * (nmb - 1) / 2 domain pairs.
        penalty /= (nmb * (nmb - 1) / 2)

    # One gradient step on the combined loss; mmd_gamma trades off
    # classification accuracy against cross-domain feature alignment.
    self.optimizer.zero_grad()
    (objective + (self.hparams['mmd_gamma'] * penalty)).backward()
    self.optimizer.step()

    # With a single domain the penalty stays the Python int 0, so it is
    # converted with .item() only when it is actually a tensor.
    if torch.is_tensor(penalty):
        penalty = penalty.item()

    return {'loss': objective.item(), 'penalty': penalty}
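
The excerpt stops before self.mmd itself. Below is a minimal, self-contained sketch of the Gaussian-kernel MMD estimator that this penalty corresponds to; the function name mmd_gaussian and the bandwidth grid gammas are illustrative assumptions rather than code from the file (DomainBed also supports a mean/covariance variant of the same penalty for CORAL).

import torch

# Illustrative sketch, not the code in domainbed/algorithms.py.
def mmd_gaussian(x, y, gammas=(0.001, 0.01, 0.1, 1, 10, 100, 1000)):
    # Biased MMD^2 estimate between feature batches x and y under a
    # sum of Gaussian (RBF) kernels over the assumed bandwidth grid.
    def kernel(a, b):
        d2 = torch.cdist(a, b).pow(2)  # pairwise squared distances
        return sum(torch.exp(-g * d2) for g in gammas)
    return kernel(x, x).mean() + kernel(y, y).mean() - 2 * kernel(x, y).mean()

# Example: feature batches drawn from matched distributions give a
# small value, mismatched distributions a larger one.
x = torch.randn(64, 128)  # features from domain i
y = torch.randn(64, 128)  # features from domain j
print(mmd_gaussian(x, y))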