def forward_backward()

in Dassl.pytorch/dassl/engine/da/mcd.py
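
This method implements the three-step adversarial training scheme of MCD (Maximum Classifier Discrepancy, Saito et al., CVPR 2018). Step A trains the feature extractor and both classifiers on labeled source data; Step B updates only the two classifiers so that they keep fitting the source data while disagreeing as much as possible on target data; Step C updates only the feature extractor to reduce that disagreement. Note the naming collision in the code below: F is torch.nn.functional, while self.F is the trainer's feature-extraction backbone (self.C1 and self.C2 are the two classifier heads).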


    def forward_backward(self, batch_x, batch_u):
        parsed = self.parse_batch_train(batch_x, batch_u)
        input_x, label_x, input_u = parsed
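        # input_x, label_x: labeled source batch; input_u: unlabeled target
        # batch (all already moved to the training device by parse_batch_train).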

        # Step A: train F, C1 and C2 jointly with supervised cross-entropy
        # on the labeled source batch.
        feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_step_A = loss_x1 + loss_x2
        self.model_backward_and_update(loss_step_A)

        # Step B: update the classifiers only. Source features are computed
        # under no_grad so no gradient reaches F; C1 and C2 are trained to
        # keep the source cross-entropy low while maximizing their
        # disagreement on target data (hence the minus sign below).
        with torch.no_grad():
            feat_x = self.F(input_x)
        logit_x1 = self.C1(feat_x)
        logit_x2 = self.C2(feat_x)
        loss_x1 = F.cross_entropy(logit_x1, label_x)
        loss_x2 = F.cross_entropy(logit_x2, label_x)
        loss_x = loss_x1 + loss_x2

        with torch.no_grad():
            feat_u = self.F(input_u)
        pred_u1 = F.softmax(self.C1(feat_u), 1)
        pred_u2 = F.softmax(self.C2(feat_u), 1)
        loss_dis = self.discrepancy(pred_u1, pred_u2)

        loss_step_B = loss_x - loss_dis
        self.model_backward_and_update(loss_step_B, ["C1", "C2"])

        # Step C: update the feature extractor only, for n_step_F inner
        # iterations, to minimize the classifiers' disagreement on target data.
        for _ in range(self.n_step_F):
            feat_u = self.F(input_u)
            pred_u1 = F.softmax(self.C1(feat_u), 1)
            pred_u2 = F.softmax(self.C2(feat_u), 1)
            loss_step_C = self.discrepancy(pred_u1, pred_u2)
            self.model_backward_and_update(loss_step_C, "F")

        # Report the most recent value of each step's loss for logging.
        loss_summary = {
            "loss_step_A": loss_step_A.item(),
            "loss_step_B": loss_step_B.item(),
            "loss_step_C": loss_step_C.item(),
        }

        # Last batch of the epoch: advance the learning-rate schedulers.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
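
Two helpers above come from elsewhere in Dassl. model_backward_and_update(loss, names) is inherited from TrainerBase: it zeroes the gradients of the named models (all registered models when names is omitted), backpropagates the loss, and steps the corresponding optimizers; this is how Step B updates only C1 and C2 and Step C updates only F. The discrepancy measure is the L1 distance between the two classifiers' softmax outputs, as in the MCD paper. A minimal sketch of the method, assuming it matches the Dassl implementation:

    def discrepancy(self, y1, y2):
        # Mean absolute difference between the two probability distributions.
        return (y1 - y2).abs().mean()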