in Dassl.pytorch/dassl/engine/dg/daeldg.py [0:0]
def forward_backward(self, batch):
    """Run one DAEL training step on a domain-grouped mini-batch.

    The parsed batch is cut into chunks of ``self.split_batch`` samples.
    Each chunk is treated as single-domain: the domain id of its first
    sample is used for the whole chunk. This presumably relies on a
    domain-grouped sampler upstream; TODO confirm against the data loader.

    For every chunk with domain id ``i``:

    * the chunk's own expert ``self.E(i, .)`` is trained on the first view
      with a soft-label cross-entropy ``-(label * log(pred + 1e-5))``;
    * the other experts, applied to the second view, are averaged and
      pulled toward the expert's detached prediction with a squared-error
      consistency loss.

    If no other expert exists (every chunk shares one domain id), the
    consistency term for that chunk is skipped. This avoids stacking an
    empty list.

    Both losses and the accuracy are divided by ``self.n_domain``, summed,
    and passed to ``self.model_backward_and_update``.
    ``self.update_lr()`` is called when ``batch_idx + 1 == num_batches``.

    Args:
        batch: raw loader output, unpacked by ``self.parse_batch_train`` into
            ``(input, input2, label, domain)``. ``label`` is indexed with
            ``.max(1)``, so it looks like a 2-D one-hot/soft label tensor
            (assumption; confirm against ``parse_batch_train``).

    Returns:
        dict: ``{"loss_x": float, "acc": float, "loss_cr": float}``.
    """
    x1, x2, label, domain = self.parse_batch_train(batch)

    # Cut every tensor along dim 0 into chunks of ``split_batch`` samples.
    x1 = torch.split(x1, self.split_batch, 0)
    x2 = torch.split(x2, self.split_batch, 0)
    label = torch.split(label, self.split_batch, 0)
    domain = torch.split(domain, self.split_batch, 0)
    # One domain id per chunk, read from the chunk's first sample.
    domain = [d[0].item() for d in domain]

    loss_x = 0
    loss_cr = 0
    acc = 0

    feat = [self.F(x) for x in x1]
    feat2 = [self.F(x) for x in x2]

    for feat_i, feat2_i, label_i, i in zip(feat, feat2, label, domain):
        # Learning expert.
        # Soft-label cross-entropy; the 1e-5 keeps log() finite. This assumes
        # self.E returns probabilities (TODO confirm).
        pred_i = self.E(i, feat_i)
        loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
        # The expert's prediction, cut from the graph, is the consistency target.
        expert_label_i = pred_i.detach()
        acc += compute_accuracy(expert_label_i, label_i.max(1)[1])[0].item()

        # Consistency regularization: average of the *other* experts,
        # evaluated on the second view.
        cr_s = [j for j in domain if j != i]
        if not cr_s:
            # No non-expert available. torch.stack/cat would raise on an
            # empty list, so this chunk adds no consistency loss.
            continue
        cr_pred = torch.stack([self.E(j, feat2_i) for j in cr_s], 1).mean(1)
        loss_cr += ((cr_pred - expert_label_i) ** 2).sum(1).mean()

    loss_x /= self.n_domain
    loss_cr /= self.n_domain
    acc /= self.n_domain

    loss = loss_x + loss_cr
    self.model_backward_and_update(loss)

    loss_summary = {
        "loss_x": loss_x.item(),
        "acc": acc,
        # float() also accepts the plain 0.0 left when no consistency term
        # was accumulated. For a 0-dim tensor it equals .item().
        "loss_cr": float(loss_cr),
    }

    # Call the LR update after the final batch (batch_idx + 1 == num_batches).
    if (self.batch_idx + 1) == self.num_batches:
        self.update_lr()

    return loss_summary