def step()

in aiops/ContraLSP/abstudy/gatemasknn_no_both.py [0:0]


    def step(self, batch, batch_idx, stage):
        # x is the data to be perturbed
        # y is the same data without perturbation
        x, y, baselines, target, *additional_forward_args = batch

        # If additional_forward_args contains only a single None,
        # collapse it to None
        if additional_forward_args == [None]:
            additional_forward_args = None

        # Get perturbed output
        # y_hat1 is computed by masking important features
        # y_hat2 is computed by masking unimportant features
        if additional_forward_args is None:
            y_hat1, y_hat2 = self(x.float(), batch_idx, baselines, target)
        else:
            y_hat1, y_hat2 = self(
                x.float(),
                batch_idx,
                baselines,
                target,
                *additional_forward_args,
            )

        # Get unperturbed output for the clean inputs
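        # (_run_forward is assumed to be Captum's generic forward helper that
        # calls forward_func on the given inputs with the target and any
        # additional arguments)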
        y_target1 = _run_forward(
            forward_func=self.net.forward_func,
            inputs=y,
            target=target,
            additional_forward_args=tuple(additional_forward_args)
            if additional_forward_args is not None
            else None,
        )

        # Mask sparsity loss (the plain L1 variant is commented out below)
        mask_ = self.net.mask[
            self.net.batch_size * batch_idx : self.net.batch_size * (batch_idx + 1)
        ]
        reg = 0.5 + 0.5 * th.erf(self.net.refactor_mask(mask_, x) / (self.net.sigma * np.sqrt(2)))
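        # 0.5 + 0.5 * erf(z / sqrt(2)) is the standard normal CDF, so reg equals
        # Phi(refactor_mask / sigma); its mean is penalized below, which is
        # assumed to act as an expected-sparsity term on the gates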
        # reg = self.net.refactor_mask(mask_, x).abs()

        # trend_reg = self.net.trend_info(x).abs().mean()
        mask_loss = self.lambda_1 * reg.mean()
        # mask_loss = self.lambda_1 * th.sum(reg, dim=[1,2]).mean()

        triplet_loss = 0
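        # If an auxiliary network is attached, penalize its output on the
        # perturbation (x - baselines) by the mean absolute value, scaled by
        # lambda_2 (the variable name suggests a triplet/contrastive-style term)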
        if self.net.model is not None:
            condition = self.net.model(x - baselines)
            triplet_loss = self.lambda_2 * condition.abs().mean()

        # Preservation mode keeps the perturbed prediction close to the clean
        # output; deletion mode pushes it away (hence the negated loss)
        if self.preservation_mode:
            main_loss = self.loss(y_hat1, y_target1)
        else:
            main_loss = -1. * self.loss(y_hat2, y_target1)

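        # Total objective: fidelity term + mask sparsity + auxiliary penalty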
        loss = main_loss + mask_loss + triplet_loss

        # Debug logging of mask statistics and the individual loss terms
        _test_mask = self.net.representation(x)
        test = (_test_mask[_test_mask > 0]).sum()
        lambda_1_t = self.lambda_1
        lambda_2_t = self.lambda_2
        reg_t = reg.mean()
        print("both", test, reg_t, triplet_loss, lambda_1_t, lambda_2_t, main_loss)

        return loss
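
A minimal usage sketch, assuming the surrounding class follows the usual PyTorch Lightning pattern (the hook names below are illustrative and not taken from the repo): step() would be shared by the training and validation hooks, which simply forward the batch and return the loss.

    def training_step(self, batch, batch_idx):
        # Delegate to the shared step; Lightning optimizes the returned loss
        return self.step(batch, batch_idx, stage="train")

    def validation_step(self, batch, batch_idx):
        # Same computation, used only for validation monitoring
        return self.step(batch, batch_idx, stage="val")

For reference, 0.5 + 0.5 * erf(z / sqrt(2)) is the standard normal CDF, so reg above can be read as Phi(refactor_mask / sigma). A self-contained check (not repo code):

    import math
    import torch as th

    z = th.linspace(-3, 3, 7)
    sigma = 0.5
    cdf_via_erf = 0.5 + 0.5 * th.erf((z / sigma) / math.sqrt(2))
    cdf_direct = th.distributions.Normal(0.0, 1.0).cdf(z / sigma)
    assert th.allclose(cdf_via_erf, cdf_direct)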