def forward_backward()

in Dassl.pytorch/dassl/engine/da/cdac.py


    def forward_backward(self, batch_x, batch_u):
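        """Run one training step: first a supervised update on labeled
        source data, then an unsupervised update on unlabeled target data
        combining the AAC, pseudo-label and consistency losses."""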

        # Global iteration index, used by the consistency ramp-up below
        current_itr = self.epoch * self.num_batches + self.batch_idx

        input_x, label_x, input_u, input_us, input_us2, label_u = self.parse_batch_train(
            batch_x, batch_u
        )
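        # input_u is the unlabeled target image (presumably weakly
        # augmented), while input_us and input_us2 are two strongly
        # augmented views of it; label_u is used only for the pseudo-label
        # quality diagnostics reported below, never as a training signal.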

        # Paper Reference Eq. 2 - Supervised Loss
        feat_x = self.F(input_x)
        logit_x = self.C(feat_x)
        loss_x = F.cross_entropy(logit_x, label_x)

        self.model_backward_and_update(loss_x)
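        # The supervised step above is plain cross-entropy on labeled
        # source samples, i.e. roughly
        #   L_sup = E_{(x, y)} [ -log softmax(C(F(x)))_y ],
        # and it updates F and C before the unsupervised losses are computed.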

        feat_u = self.F(input_u)
        feat_us = self.F(input_us)
        feat_us2 = self.F(input_us2)

        # Paper Reference Eq. 3 - Adversarial Adaptive Clustering (AAC) Loss
        logit_u = self.C(feat_u, reverse=True)
        logit_us = self.C(feat_us, reverse=True)
        prob_u, prob_us = F.softmax(logit_u, dim=1), F.softmax(logit_us, dim=1)

        # Get similarity matrix s_ij
        sim_mat = self.get_similarity_matrix(feat_u, self.topk, self.device)

        aac_loss = -1.0 * self.aac_criterion(sim_mat, prob_u, prob_us)
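        # sim_mat holds the binary pairwise similarities s_ij; in the CDAC
        # reference implementation, s_ij = 1 when the top-k ranked entries
        # of two unlabeled samples' features coincide and 0 otherwise
        # (get_similarity_matrix is assumed to follow that scheme). Assuming
        # aac_criterion returns the pairwise log-likelihood of Eq. 3, roughly
        #   E[ s_ij * log(p_i . p'_j) + (1 - s_ij) * log(1 - p_i . p'_j) ],
        # the -1 turns it into a loss to minimize, and the gradient reversal
        # applied in C above makes it adversarial with respect to F.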

        # Paper Reference Eq. 4 - Pseudo-label Loss
        logit_u = self.C(feat_u)
        logit_us = self.C(feat_us)
        logit_us2 = self.C(feat_us2)
        prob_u = F.softmax(logit_u, dim=1)
        prob_us = F.softmax(logit_us, dim=1)
        prob_us2 = F.softmax(logit_us2, dim=1)
        prob_u = prob_u.detach()
        max_probs, max_idx = torch.max(prob_u, dim=-1)
        mask = max_probs.ge(self.p_thresh).float()
        p_u_stats = self.assess_y_pred_quality(max_idx, label_u, mask)

        pl_loss = (
            F.cross_entropy(logit_us2, max_idx, reduction='none') * mask
        ).mean()
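        # Eq. 4 is FixMatch-style self-training: pseudo-labels (max_idx) come
        # from the weakly augmented view and are detached, the loss is applied
        # to the second strongly augmented view, and mask zeroes out samples
        # whose top probability falls below p_thresh before averaging over
        # the full batch.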

        # Paper Reference Eq. 8 - Consistency Loss
        cons_multi = self.sigmoid_rampup(
            current_itr=current_itr, rampup_itr=self.rampup_iters
        ) * self.rampup_coef
        cons_loss = cons_multi * F.mse_loss(prob_us, prob_us2)
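        # The consistency weight cons_multi is ramped up from 0 to
        # rampup_coef over rampup_iters iterations, so the MSE term between
        # the two strongly augmented views is phased in gradually.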

        loss_u = aac_loss + pl_loss + cons_loss

        self.model_backward_and_update(loss_u)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(logit_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "aac_loss": aac_loss.item(),
            "pl_loss": pl_loss.item(),
            "cons_loss": cons_loss.item(),
            "p_u_pred_acc": p_u_stats["acc_raw"],
            "p_u_pred_acc_thre": p_u_stats["acc_thre"],
            "p_u_pred_keep": p_u_stats["keep_rate"]
        }

        # Update LR after every iteration as mentioned in the paper
        self.update_lr()

        return loss_summary
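
The helper routines used above (sigmoid_rampup, get_similarity_matrix, assess_y_pred_quality and aac_criterion) are defined elsewhere in cdac.py. As one illustration, here is a minimal sketch of a sigmoid ramp-up, assuming the common Mean Teacher schedule exp(-5 * (1 - t)^2); the actual implementation in cdac.py may differ:

    import math

    def sigmoid_rampup(current_itr, rampup_itr):
        # Ramps from ~0 at iteration 0 to 1 at rampup_itr, then stays at 1.
        if rampup_itr <= 0:
            return 1.0
        t = min(max(current_itr / rampup_itr, 0.0), 1.0)
        return math.exp(-5.0 * (1.0 - t) ** 2)

Multiplied by rampup_coef as in the listing, this weight rises smoothly from near zero to the full coefficient, so the consistency term only kicks in once the pseudo-labels have had time to stabilize.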