in aiops/ContraLSP/abstudy/gatemasknn_no_both.py [0:0]
def step(self, batch, batch_idx, stage):
    """Compute the total objective for one batch of the mask explainer.

    Args:
        batch: Sequence unpacked as ``(x, y, baselines, target, *extra)``.
            ``x`` is the data to be perturbed and ``y`` is the same data
            without perturbation. ``extra`` holds additional forward
            arguments; a lone ``None`` there is treated as "no extra args".
        batch_idx: Index of the batch. It selects the rows
            ``[batch_size * batch_idx, batch_size * (batch_idx + 1))`` of
            ``self.net.mask`` for the regularisation term.
        stage: Not used by this implementation.

    Returns:
        ``main_loss + mask_loss + triplet_loss``, where:

        - ``main_loss`` is ``self.loss(y_hat1, y_target1)`` in preservation
          mode, or ``-self.loss(y_hat2, y_target1)`` otherwise.
        - ``mask_loss`` is ``lambda_1`` times the mean Gaussian CDF of the
          refactored mask.
        - ``triplet_loss`` is ``lambda_2`` times the mean absolute output of
          ``self.net.model`` (``0`` when there is no model).

    Side effects:
        Prints one diagnostic line tagged ``"both"`` on every call.
    """
    # x is the data to be perturbed; y is the same data without perturbation.
    x, y, baselines, target, *additional_forward_args = batch

    # A single None in the extra args means "no extra forward args".
    if additional_forward_args == [None]:
        additional_forward_args = None

    # y_hat1 is computed by masking important features,
    # y_hat2 by masking unimportant features.
    if additional_forward_args is None:
        y_hat1, y_hat2 = self(x.float(), batch_idx, baselines, target)
    else:
        y_hat1, y_hat2 = self(
            x.float(),
            batch_idx,
            baselines,
            target,
            *additional_forward_args,
        )

    # Unperturbed model output on the clean inputs: the reference the
    # preservation / deletion loss compares against.
    y_target1 = _run_forward(
        forward_func=self.net.forward_func,
        inputs=y,
        target=target,
        additional_forward_args=tuple(additional_forward_args)
        if additional_forward_args is not None
        else None,
    )

    # Mask regularisation. Take the mask rows belonging to this batch and pass
    # them through refactor_mask and the Gaussian CDF:
    #   Phi(m / sigma) = 0.5 + 0.5 * erf(m / (sigma * sqrt(2)))
    start = self.net.batch_size * batch_idx
    mask_ = self.net.mask[start : start + self.net.batch_size]
    reg = 0.5 + 0.5 * th.erf(
        self.net.refactor_mask(mask_, x) / (self.net.sigma * np.sqrt(2))
    )
    # Computed once; reused for the loss and for the diagnostic print.
    reg_mean = reg.mean()
    mask_loss = self.lambda_1 * reg_mean

    # L1 penalty on the condition network's output, kept under the
    # historical name ``triplet_loss``. It stays 0 when there is no model.
    triplet_loss = 0
    if self.net.model is not None:
        condition = self.net.model(x - baselines)
        triplet_loss = self.lambda_2 * condition.abs().mean()

    if self.preservation_mode:
        main_loss = self.loss(y_hat1, y_target1)
    else:
        # Minimising the negated loss pushes y_hat2 away from y_target1.
        main_loss = -1.0 * self.loss(y_hat2, y_target1)

    loss = main_loss + mask_loss + triplet_loss

    # Diagnostics only. The mask representation is not part of the loss, so
    # skip building an autograd graph for it. This assumes representation()
    # does not itself need gradients (TODO confirm).
    with th.no_grad():
        representation = self.net.representation(x)
        positive_mass = representation[representation > 0].sum()
    print(
        "both",
        positive_mass,
        reg_mean,
        triplet_loss,
        self.lambda_1,
        self.lambda_2,
        main_loss,
    )
    return loss