in domainbed/algorithms.py:
def update(self, minibatches, unlabeled=None):
    """One optimization step: cross-entropy plus a self-regularization
    term that pulls each example's logits and projected features toward
    those of same-class examples (shuffled and mixed within the batch)."""
    all_x = torch.cat([x for x, y in minibatches])
    all_y = torch.cat([y for _, y in minibatches])

    # mixup coefficient drawn from Beta(0.5, 0.5)
    lam = np.random.beta(0.5, 0.5)

    batch_size = all_y.size(0)
    # sort the batch by class label so same-class examples are contiguous,
    # and record the exclusive end index of each class group
    with torch.no_grad():
        sorted_y, indices = torch.sort(all_y)
        sorted_x = torch.zeros_like(all_x)
        for idx, order in enumerate(indices):
            sorted_x[idx] = all_x[order]
        intervals = []
        ex = 0
        for idx, val in enumerate(sorted_y):
            if ex == val:
                continue
            intervals.append(idx)
            ex = val
        intervals.append(batch_size)

        all_x = sorted_x
        all_y = sorted_y
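    # note: the gather loop above is effectively equivalent to
    # `sorted_x = all_x[indices]`, and the boundary scan to
    # `torch.unique_consecutive(sorted_y, return_counts=True)[1].cumsum(0)`
    # (the end index of each class group)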
    feat = self.featurizer(all_x)    # backbone features
    proj = self.cdpl(feat)           # CDPL projection of the features
    output = self.classifier(feat)   # class logits
    # build two independent intra-class shuffles of the logits and projected
    # features (positions are only permuted within a same-class group)
    output_2 = torch.zeros_like(output)
    feat_2 = torch.zeros_like(proj)
    output_3 = torch.zeros_like(output)
    feat_3 = torch.zeros_like(proj)
    ex = 0
    for end in intervals:
        shuffle_indices = torch.randperm(end - ex) + ex
        shuffle_indices2 = torch.randperm(end - ex) + ex
        for idx in range(end - ex):
            output_2[idx + ex] = output[shuffle_indices[idx]]
            feat_2[idx + ex] = proj[shuffle_indices[idx]]
            output_3[idx + ex] = output[shuffle_indices2[idx]]
            feat_3[idx + ex] = proj[shuffle_indices2[idx]]
        ex = end
    # mixup: blend the two shuffled copies with coefficient lam
    output_3 = lam * output_2 + (1 - lam) * output_3
    feat_3 = lam * feat_2 + (1 - lam) * feat_3
    # self-regularization: match each example's logits and features to a
    # same-class neighbor ("ind", shuffled) and to a mixed pair ("hdl");
    # the feature terms are down-weighted by 0.3
    L_ind_logit = self.MSEloss(output, output_2)
    L_hdl_logit = self.MSEloss(output, output_3)
    L_ind_feat = 0.3 * self.MSEloss(feat, feat_2)
    L_hdl_feat = 0.3 * self.MSEloss(feat, feat_3)

    cl_loss = F.cross_entropy(output, all_y)
    # scale the regularizer by the classification loss, capped at 1, so it
    # shrinks as the classifier fits the data
    C_scale = min(cl_loss.item(), 1.)
    loss = cl_loss + C_scale * (lam * (L_ind_logit + L_ind_feat)
                                + (1 - lam) * (L_hdl_logit + L_hdl_feat))
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    return {'loss': loss.item()}
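
# --- A minimal sketch, not part of the original file: the nested copy loops
# in the "shuffle" step can be replaced by a single gather per shuffle. The
# helper name `intra_class_permutation` is hypothetical; it assumes the
# class-sorted batch and `intervals` layout built above (one exclusive end
# index per class group).
import torch

def intra_class_permutation(intervals, batch_size, device=None):
    """Indices that permute positions only within each class group."""
    perm = torch.empty(batch_size, dtype=torch.long, device=device)
    start = 0
    for end in intervals:
        # random permutation of [start, end), i.e. within one class group
        perm[start:end] = torch.randperm(end - start, device=device) + start
        start = end
    return perm

# Usage sketch, e.g. inside update():
#   perm = intra_class_permutation(intervals, batch_size, device=output.device)
#   output_2, feat_2 = output[perm], proj[perm]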