def forward_backward()

in Dassl.pytorch/dassl/engine/dg/ddaig.py
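
One training step of DDAIG (Deep Domain-Adversarial Image Generation). Each batch drives three updates in turn: the perturbation generator G learns to transform images so that the label classifier F still predicts the correct class while the domain classifier D is fooled; F is then trained on the clean images (and, after a warmup period, on a freshly perturbed copy as well); finally, D is trained to recognise the domain of the clean images.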


    def forward_backward(self, batch):
        input, label, domain = self.parse_batch_train(batch)

        #############
        # Update G
        #############
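        # G perturbs the input; lmda scales the perturbation strength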
        input_p = self.G(input, lmda=self.lmda)
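        # Optionally clamp perturbed pixels back to a valid intensity range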
        if self.clamp:
            input_p = torch.clamp(
                input_p, min=self.clamp_min, max=self.clamp_max
            )
        loss_g = 0
        # Minimize label loss
        loss_g += F.cross_entropy(self.F(input_p), label)
        # Maximize domain loss
        loss_g -= F.cross_entropy(self.D(input_p), domain)
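        # Backprop and step only the optimizer registered under "G"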
        self.model_backward_and_update(loss_g, "G")

        # Perturb data with new G
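        # no_grad: the perturbed images below are used only as training
        # data for F, so no gradient needs to flow back into G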
        with torch.no_grad():
            input_p = self.G(input, lmda=self.lmda)
            if self.clamp:
                input_p = torch.clamp(
                    input_p, min=self.clamp_min, max=self.clamp_max
                )

        #############
        # Update F
        #############
        loss_f = F.cross_entropy(self.F(input), label)
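        # After the warmup epochs, also train F on the perturbed images,
        # blending the clean and perturbed losses with weight alpha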
        if (self.epoch + 1) > self.warmup:
            loss_fp = F.cross_entropy(self.F(input_p), label)
            loss_f = (1.0 - self.alpha) * loss_f + self.alpha * loss_fp
        self.model_backward_and_update(loss_f, "F")

        #############
        # Update D
        #############
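        # Train D to predict the source domain of the original images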
        loss_d = F.cross_entropy(self.D(input), domain)
        self.model_backward_and_update(loss_d, "D")

        loss_summary = {
            "loss_g": loss_g.item(),
            "loss_f": loss_f.item(),
            "loss_d": loss_d.item(),
        }

        # End of the epoch: advance the learning-rate schedulers
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
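
For reference, a minimal sketch of the two trainer helpers this method leans on, following the conventions Dassl.pytorch uses elsewhere (dict batches with "img"/"label"/"domain" keys, and per-model optimizers registered under the names "G", "F", and "D"). The bodies below are illustrative assumptions, not the library's exact code:

    def parse_batch_train(self, batch):
        # Sketch: Dassl batches are dicts; move tensors to the training device
        input = batch["img"].to(self.device)
        label = batch["label"].to(self.device)
        domain = batch["domain"].to(self.device)
        return input, label, domain

    def model_backward_and_update(self, loss, names=None):
        # Sketch: zero grads, backprop, and step only the named model(s),
        # e.g. "G" above; the other registered optimizers are untouched
        self.model_zero_grad(names)
        self.model_backward(loss)
        self.model_update(names)

G is assumed to return a perturbed image of the same shape as its input; the FCN generators bundled with Dassl implement roughly x_p = x + lmda * p(x), which is why lmda acts as a perturbation-strength knob and why clamping back to a valid pixel range can be necessary.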