def forward_backward()

in Dassl.pytorch/dassl/engine/ssl/mixmatch.py

One MixMatch training step: the unlabeled batch is pseudo-labeled by averaging and sharpening the model's predictions over its augmented views, labeled and unlabeled samples are blended with mixup, and the model is updated with a cross-entropy loss on the labeled part plus a linearly ramped-up L2 consistency loss on the unlabeled part.

    def forward_backward(self, batch_x, batch_u):
        input_x, label_x, input_u = self.parse_batch_train(batch_x, batch_u)
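        # input_u is a list of tensors, one per augmented view of the unlabeled batch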
        num_x = input_x.shape[0]

        # Weight on the unlabeled loss ramps up linearly over training steps
        global_step = self.batch_idx + self.epoch * self.num_batches
        weight_u = self.weight_u * linear_rampup(global_step, self.rampup)

        # Generate pseudo-label for unlabeled data
        with torch.no_grad():
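            # Average the softmax predictions over all augmented views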
            output_u = 0
            for input_ui in input_u:
                output_ui = F.softmax(self.model(input_ui), 1)
                output_u += output_ui
            output_u /= len(input_u)
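            # Sharpen the averaged distribution into a pseudo-label, then
            # replicate it so every augmented view shares the same target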
            label_u = sharpen_prob(output_u, self.temp)
            label_u = [label_u] * len(input_u)
            label_u = torch.cat(label_u, 0)
            input_u = torch.cat(input_u, 0)

        # Combine and shuffle labeled and unlabeled data; the shuffled pool
        # supplies the mixup partners for both batches below
        input_xu = torch.cat([input_x, input_u], 0)
        label_xu = torch.cat([label_x, label_u], 0)
        input_xu, label_xu = shuffle_index(input_xu, label_xu)

        # Mixup; preserve_order=True keeps each mixed sample closer to its
        # first input, as in the MixMatch recipe
        input_x, label_x = mixup(
            input_x,
            input_xu[:num_x],
            label_x,
            label_xu[:num_x],
            self.beta,
            preserve_order=True,
        )

        input_u, label_u = mixup(
            input_u,
            input_xu[num_x:],
            label_u,
            label_xu[num_x:],
            self.beta,
            preserve_order=True,
        )

        # Labeled loss: cross-entropy against the (mixed) soft labels, with a
        # small epsilon inside the log for numerical stability
        output_x = F.softmax(self.model(input_x), 1)
        loss_x = (-label_x * torch.log(output_x + 1e-5)).sum(1).mean()

        # Unlabeled loss: squared error between predictions and (mixed) pseudo-labels
        output_u = F.softmax(self.model(input_u), 1)
        loss_u = ((label_u - output_u)**2).mean()

        loss = loss_x + loss_u * weight_u
        self.model_backward_and_update(loss)

        loss_summary = {"loss_x": loss_x.item(), "loss_u": loss_u.item()}

        # Update the learning rate once per epoch, after the last batch
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
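
The method assumes module-level imports of torch and torch.nn.functional as F, plus four helpers defined elsewhere in Dassl.pytorch: linear_rampup, sharpen_prob, shuffle_index, and mixup. The sketch below is a minimal reconstruction inferred from the call sites above and from the MixMatch recipe, not the library's actual code; in particular, the 1/T sharpening exponent and the lambda = max(lambda, 1 - lambda) behavior of preserve_order=True are assumptions.

    import numpy as np
    import torch

    def linear_rampup(current, rampup_length):
        # Scale linearly from 0 to 1 over rampup_length steps; 1.0 thereafter
        if rampup_length == 0:
            return 1.0
        return float(np.clip(current / rampup_length, 0.0, 1.0))

    def sharpen_prob(p, temperature):
        # Sharpen a (batch, n_classes) probability matrix: p**(1/T), renormalized
        sharpened = p ** (1 / temperature)
        return sharpened / sharpened.sum(1, keepdim=True)

    def shuffle_index(*tensors):
        # Shuffle several same-length tensors with one shared permutation
        perm = torch.randperm(tensors[0].shape[0])
        return [t[perm] for t in tensors]

    def mixup(x1, x2, y1, y2, beta, preserve_order=False):
        # Per-sample convex combination with lambda ~ Beta(beta, beta);
        # preserve_order=True clamps lambda >= 0.5 so the mix stays closer to x1
        lmda = torch.distributions.Beta(beta, beta).sample((x1.shape[0], 1, 1, 1))
        lmda = lmda.to(x1.device)
        if preserve_order:
            lmda = torch.max(lmda, 1 - lmda)
        xmix = lmda * x1 + (1 - lmda) * x2
        lmda = lmda.view(-1, 1)
        ymix = lmda * y1 + (1 - lmda) * y2
        return xmix, ymix

Under these assumptions, the total loss matches the MixMatch objective loss_x + weight_u * loss_u: cross-entropy on the mixed labeled batch plus a squared-error consistency term on the mixed unlabeled batch, with weight_u ramped up linearly from zero.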