in Dassl.pytorch/dassl/engine/da/dael.py [0:0]
def forward_backward(self, batch_x, batch_u):
    """Run one optimisation step on a labeled and an unlabeled batch.

    Three terms are combined into the optimised loss:

    * ``loss_x``: cross-entropy of each expert on the chunk attributed to
      its own domain.
    * ``loss_cr``: squared error between the mean prediction of the other
      experts on the second view of a chunk and the detached prediction
      of the chunk's own expert on the first view.
    * ``loss_u``: cross-entropy between the expert-averaged prediction on
      the second unlabeled view and the pseudo label taken from the most
      confident expert on the first unlabeled view, masked by
      ``self.conf_thre`` and scaled by ``self.weight_u``.

    Args:
        batch_x: Labeled batch, passed unchanged to
            ``self.parse_batch_train``.
        batch_u: Unlabeled batch, passed unchanged to
            ``self.parse_batch_train``.

    Returns:
        dict: Python scalars under the keys ``loss_x``, ``acc_x``,
        ``loss_cr`` and ``loss_u``.
    """
    parsed_data = self.parse_batch_train(batch_x, batch_u)
    input_x, input_x2, label_x, domain_x, input_u, input_u2 = parsed_data

    # Cut the labeled tensors into chunks of ``self.split_batch`` rows; a
    # chunk is attributed to the domain id of its first sample.
    # NOTE(review): presumably every chunk holds samples of one domain
    # only -- confirm against the sampler used by the caller.
    input_x = torch.split(input_x, self.split_batch, 0)
    input_x2 = torch.split(input_x2, self.split_batch, 0)
    label_x = torch.split(label_x, self.split_batch, 0)
    domain_x = torch.split(domain_x, self.split_batch, 0)
    domain_x = [d[0].item() for d in domain_x]

    # Generate pseudo label
    with torch.no_grad():
        feat_u = self.F(input_u)
        pred_u = torch.stack(
            [self.E(k, feat_u) for k in range(self.num_source_domains)], 1
        )  # (B, K, C)
        # Get the highest probability and index (label) for each expert
        experts_max_p, experts_max_idx = pred_u.max(2)  # (B, K)
        # Get the most confident expert
        max_expert_p, max_expert_idx = experts_max_p.max(1)  # (B)
        # For every sample pick the label voted by its most confident
        # expert. A single gather replaces a per-sample Python loop.
        pseudo_label_u = experts_max_idx.gather(
            1, max_expert_idx.unsqueeze(1)
        ).squeeze(1)  # (B)
        pseudo_label_u = create_onehot(pseudo_label_u, self.num_classes)
        pseudo_label_u = pseudo_label_u.to(self.device)
        # 1.0 where the winning expert reaches the confidence threshold.
        label_u_mask = (max_expert_p >= self.conf_thre).float()

    loss_x = 0
    loss_cr = 0
    acc_x = 0

    feat_x = [self.F(x) for x in input_x]
    feat_x2 = [self.F(x) for x in input_x2]
    feat_u2 = self.F(input_u2)

    for feat_xi, feat_x2i, label_xi, i in zip(
        feat_x, feat_x2, label_x, domain_x
    ):
        # Learning expert
        pred_xi = self.E(i, feat_xi)
        # The 1e-5 offset keeps the logarithm finite for zero entries
        # (pred_xi is presumably a probability vector -- confirm in E).
        loss_x += (-label_xi * torch.log(pred_xi + 1e-5)).sum(1).mean()
        expert_label_xi = pred_xi.detach()
        acc_x += compute_accuracy(
            expert_label_xi, label_xi.max(1)[1]
        )[0].item()

        # Consistency regularization: average the predictions that the
        # experts of all *other* domains in this batch give for the second
        # view. With no other domain in the batch the list is empty and
        # torch.stack raises.
        cr_pred = torch.stack(
            [self.E(j, feat_x2i) for j in domain_x if j != i], 1
        ).mean(1)
        loss_cr += ((cr_pred - expert_label_xi) ** 2).sum(1).mean()

    loss_x /= self.n_domain
    loss_cr /= self.n_domain
    acc_x /= self.n_domain

    # Unsupervised loss
    pred_u = torch.stack(
        [self.E(k, feat_u2) for k in range(self.num_source_domains)], 1
    ).mean(1)
    l_u = (-pseudo_label_u * torch.log(pred_u + 1e-5)).sum(1)
    loss_u = (l_u * label_u_mask).mean()

    loss = loss_x + loss_cr + loss_u * self.weight_u
    self.model_backward_and_update(loss)

    loss_summary = {
        "loss_x": loss_x.item(),
        "acc_x": acc_x,
        "loss_cr": loss_cr.item(),
        "loss_u": loss_u.item(),
    }

    # Update the learning rate once the last batch index is reached.
    if (self.batch_idx + 1) == self.num_batches:
        self.update_lr()

    return loss_summary