in domainbed/algorithms.py [0:0]
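# Module-level imports assumed by this excerpt (defined at the top of
# domainbed/algorithms.py): copy, torch, torch.nn as nn,
# torch.nn.functional as F, and torch.autograd as autograd.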
def update(self, minibatches):
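    """One TRM training step.

    For the first hparams['iters'] steps this is plain ERM; after that
    warm-up, a transfer-risk ("swap") penalty is added that aligns each
    held-out domain's gradient with the alpha-weighted source gradients.
    """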
    loss_swap = 0.0
    trm = 0.0  # vestigial here: rescaled at the end but never accumulated or returned
    if self.update_count >= self.hparams['iters']:
        # TRM
        if self.hparams['class_balanced']:
            # for stability when facing unbalanced labels across environments
            for classifier in self.clist:
                classifier.weight.data = copy.deepcopy(self.classifier.weight.data)
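        # renormalize each row of alpha so the weights over source domains sum to 1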
        self.alpha /= self.alpha.sum(1, keepdim=True)

        self.featurizer.train()
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        all_feature = self.featurizer(all_x)
        # updating original network
        loss = F.cross_entropy(self.classifier(all_feature), all_y)
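        # fit one linear classifier per domain on the (detached) shared
        # features; the featurizer receives no gradient from this inner loop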
        for i in range(30):
            all_logits_idx = 0
            loss_erm = 0.
            for j, (x, y) in enumerate(minibatches):
                # j-th domain
                feature = all_feature[all_logits_idx:all_logits_idx + x.shape[0]]
                all_logits_idx += x.shape[0]
                loss_erm += F.cross_entropy(self.clist[j](feature.detach()), y)
            for opt in self.olist:
                opt.zero_grad()
            loss_erm.backward()
            for opt in self.olist:
                opt.step()
        # collect (feature, y)
        feature_split = list()
        y_split = list()
        all_logits_idx = 0
        for i, (x, y) in enumerate(minibatches):
            feature = all_feature[all_logits_idx:all_logits_idx + x.shape[0]]
            all_logits_idx += x.shape[0]
            feature_split.append(feature)
            y_split.append(y)
        # estimate transfer risk
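        # for each domain Q treated as the held-out "target", compare the
        # gradient of Q's risk against the alpha-weighted gradient of the
        # remaining domains' risks, both w.r.t. the Q-specific classifier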
        for Q, (x, y) in enumerate(minibatches):
            sample_list = list(range(len(minibatches)))
            sample_list.remove(Q)
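            # gradient of the held-out risk w.r.t. the Q-classifier weights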
            loss_Q = F.cross_entropy(self.clist[Q](feature_split[Q]), y_split[Q])
            grad_Q = autograd.grad(loss_Q, self.clist[Q].weight, create_graph=True)
            vec_grad_Q = nn.utils.parameters_to_vector(grad_Q)

            loss_P = [F.cross_entropy(self.clist[Q](feature_split[i]), y_split[i]) * (self.alpha[Q, i].data.detach())
                      if i in sample_list else 0. for i in range(len(minibatches))]
            loss_P_sum = sum(loss_P)
            grad_P = autograd.grad(loss_P_sum, self.clist[Q].weight, create_graph=True)
            vec_grad_P = nn.utils.parameters_to_vector(grad_P).detach()
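            # precondition the source gradient with an approximate inverse
            # Hessian-vector product (self.neum, not shown here, appears to
            # use a truncated Neumann-series approximation)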
            vec_grad_P = self.neum(vec_grad_P, self.clist[Q], (feature_split[Q], y_split[Q]))

            loss_swap += loss_P_sum - self.hparams['cos_lambda'] * (vec_grad_P.detach() @ vec_grad_Q)
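            # multiplicative-weights (exponentiated-gradient) update on alpha,
            # GroupDRO-style: upweight source domains with higher loss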
            for i in sample_list:
                self.alpha[Q, i] *= (self.hparams["groupdro_eta"] * loss_P[i].data).exp()

        loss_swap /= len(minibatches)
        trm /= len(minibatches)
    else:
        # ERM
        self.featurizer.train()
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        all_feature = self.featurizer(all_x)
        loss = F.cross_entropy(self.classifier(all_feature), all_y)
    nll = loss.item()
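    # joint update of the featurizer and the main classifier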
    self.optimizer_c.zero_grad()
    self.optimizer_f.zero_grad()
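    # during the warm-up phase only the ERM loss is backpropagated;
    # afterwards the swap penalty is added on top of it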
    if self.update_count >= self.hparams['iters']:
        loss_swap = (loss + loss_swap)
    else:
        loss_swap = loss
    loss_swap.backward()
    self.optimizer_f.step()
    self.optimizer_c.step()
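    # report only the penalty component (total objective minus the nll part)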
    loss_swap = loss_swap.item() - nll
    self.update_count += 1

    return {'nll': nll, 'trm_loss': loss_swap}
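
# ----------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): how a training
# loop might drive this update, assuming DomainBed-style infinite data
# loaders. `algorithm`, `train_loaders`, `device`, and `n_steps` are
# hypothetical names standing in for the caller's objects.
#
#   minibatches_iterator = zip(*train_loaders)
#   for step in range(n_steps):
#       minibatches = [(x.to(device), y.to(device))
#                      for x, y in next(minibatches_iterator)]
#       step_vals = algorithm.update(minibatches)
#       # step_vals == {'nll': ..., 'trm_loss': ...}
# ----------------------------------------------------------------------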