aiops/ContraLSP/demo.py (133 lines of code) (raw):
import argparse
import os
import random as rd

import numpy as np
import torch
from torch import nn
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import Trainer, seed_everything
from tint.attr import (
    ExtremalMask,
    DynaMask,
)
from tint.attr.models import (
    ExtremalMaskNet,
    MaskNet,
)
from tint.models import MLP, RNN

from utils.tools import print_results, plot_example_box
from attribution.gatemasknn import *
from attribution.gate_mask import *
# Seed Python's `random`, NumPy, PyTorch and Lightning (in that order) with the
# same value so repeated runs of the demo produce the same data and masks.
for _seed_fn in (rd.seed, np.random.seed, torch.manual_seed, seed_everything):
    _seed_fn(42)
del _seed_fn

# Prefer the GPU when one is present; fall back to the CPU otherwise.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")
del _cuda_ok
# Synthetic benchmark: a hand-written "black box" with known ground-truth
# saliency is explained by GateMask, ExtremalMask and DynaMask; per-sample
# saliency maps are printed via `print_results` and plotted to PLOT_DIR.

# Directory the per-sample saliency plots are written to.
PLOT_DIR = "./plot/demo2plot"
# Explainers this demo knows how to run.
EXPLAINERS = ("gatemask", "extrmask", "dynamask")
# Smallest input sizes `black_box` supports: level 1 writes output steps 8:18
# from input steps 0:10, and level 3 reads feature index 2.
MIN_TIME_STEPS = 18
MIN_FEATURES = 3


def black_box(input):  # `input` name kept: this is passed as forward_func
    """Synthetic model whose true saliency is known by construction.

    Maps ``input`` of shape ``(bs, T, F)`` to an output of shape
    ``(bs, T, 1)``. Samples are split into three groups by batch index, and
    every output entry not written below stays zero:

    * level 1 (samples ``:25``): output steps 8:18 = feature 0 at steps 0:10;
    * level 2 (samples ``25:50``): output steps ``8:`` = feature 1, same steps;
    * level 3 (samples ``50:``): output steps 0:4 = ``(f0 + f2) ** 2``.

    Requires ``T >= 18`` and ``F >= 3``; features beyond index 2 are ignored.
    """
    num_samples, time_len, _ = input.shape
    # Allocate on the input's device/dtype: DynaMask is run with `device`,
    # which is CUDA when available, and a CPU buffer would reject GPU slices.
    output = input.new_zeros((num_samples, time_len, 1))
    # level 1
    output[:25, 8:18, :] = input[:25, 0:10, 0:1]
    # level 2
    output[25:50, 8:, :] = input[25:50, 8:, 1:2]
    # level 3 -- `2:3` (not `2:`) keeps a single channel when F > 3.
    output[50:, 0:4, :] = (input[50:, 0:4, 0:1] + input[50:, 0:4, 2:3]) ** 2
    return output


def build_true_saliency(shape):
    """Return the ground-truth saliency of :func:`black_box` for ``shape``.

    ``shape`` is ``(bs, T, F)``; the result holds 1 where the output of
    ``black_box`` depends on that input entry and 0 elsewhere.
    """
    saliency = torch.zeros(shape)
    saliency[:25, 0:10, 0:1] = 1
    saliency[25:50, 8:, 1:2] = 1
    saliency[50:, 0:4, 0:1] = 1
    saliency[50:, 0:4, 2:3] = 1
    return saliency


def _make_trainer():
    """Return the Lightning trainer configuration shared by the mask nets."""
    return Trainer(
        max_epochs=50,
        accelerator="cpu",
        log_every_n_steps=2,
        callbacks=[EarlyStopping("train_loss", patience=10, mode="min")],
    )


def _make_mask_model(n_features):
    """Return a fresh bidirectional-GRU + MLP network for mask prediction."""
    return nn.Sequential(
        RNN(
            input_size=n_features,
            rnn="gru",
            hidden_size=n_features,
            bidirectional=True,
        ),
        MLP([2 * n_features, n_features]),
    )


def _save_plots(saliency, tag, n_samples):
    """Plot samples ``0..n_samples-1`` to ``PLOT_DIR/<tag>_<i>.png``."""
    for i in range(n_samples):
        plot_example_box(saliency, i, f"{PLOT_DIR}/{tag}_{i}.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Unit Testing")
    parser.add_argument("--feature_num", default=3, type=int)
    parser.add_argument("--ts", default=10 * 2, type=int)
    parser.add_argument("--bs", default=100, type=int)
    parser.add_argument(
        "--explainers",
        nargs="+",
        choices=EXPLAINERS,
        default=list(EXPLAINERS),
        help="explainers to run (default: all)",
    )
    args = parser.parse_args()
    # Reject sizes black_box cannot handle instead of failing mid-training
    # with a tensor shape mismatch.
    if args.ts < MIN_TIME_STEPS:
        parser.error(f"--ts must be at least {MIN_TIME_STEPS}")
    if args.feature_num < MIN_FEATURES:
        parser.error(f"--feature_num must be at least {MIN_FEATURES}")
    explainers = args.explainers

    X = torch.randint(low=0, high=5, size=(args.bs, args.ts, args.feature_num)).float()
    true_saliency = build_true_saliency(X.shape)
    print("-===============", true_saliency.sum())
    os.makedirs(PLOT_DIR, exist_ok=True)
    _save_plots(true_saliency, "true", args.bs)

    if "gatemask" in explainers:
        trainer = _make_trainer()
        mask = GateMaskNet(
            forward_func=black_box,
            model=_make_mask_model(X.shape[-1]),
            lambda_1=0.1,  # 0.1 for our lambda is suitable
            lambda_2=0.1,
            optim="adam",
            lr=0.1,
        )
        explainer = GateMask(black_box)
        _attr = explainer.attribute(
            X,
            trainer=trainer,
            mask_net=mask,
            batch_size=args.bs,
            win_size=5,
            sigma=0.5,
        )
        gatemask_saliency = _attr.clone().detach()
        print_results(gatemask_saliency, true_saliency)
        _save_plots(gatemask_saliency, "gatemask", args.bs)

    if "extrmask" in explainers:
        trainer = _make_trainer()
        mask = ExtremalMaskNet(
            forward_func=black_box,
            model=_make_mask_model(X.shape[-1]),
            optim="adam",
            lr=0.1,
        )
        explainer = ExtremalMask(black_box)
        _attr = explainer.attribute(
            X,
            trainer=trainer,
            mask_net=mask,
            batch_size=args.bs,
        )
        # .cpu() is a no-op on CPU and makes .numpy() safe on GPU tensors.
        nnmask_saliency = _attr.clone().detach().cpu().numpy()
        print_results(nnmask_saliency, true_saliency)
        _save_plots(nnmask_saliency, "extrmask", args.bs)

    if "dynamask" in explainers:
        from attribution.mask_group import MaskGroup
        from attribution.perturbation import GaussianBlur
        from utils.losses import mse_multiple

        pert = GaussianBlur(device=device)  # We use a Gaussian Blur perturbation operator
        mask_group = MaskGroup(perturbation=pert, device=device, random_seed=42)
        mask_group.fit_multiple(
            f=black_box,
            X=X,
            area_list=np.arange(0.01, 0.51, 0.01),
            loss_function_multiple=mse_multiple,
            n_epoch=50,
            learning_rate=0.1,
        )
        thresh = 0.01 * torch.ones(args.bs)
        mask = mask_group.get_extremal_mask_multiple(thresh)  # The mask with the lowest error is selected
        # `device` may be CUDA here, so move to host before converting.
        dynamask_saliency = mask.clone().detach().cpu().numpy()
        print_results(dynamask_saliency, true_saliency)
        _save_plots(dynamask_saliency, "dynamask", args.bs)