# aiops/ContraLSP/attribution/gate_mask.py

import copy

import torch as th
from captum.attr._utils.attribution import PerturbationAttribution
from captum.log import log_usage
from captum._utils.common import (
    _format_baseline,
    _format_inputs,
    _format_output,
    _is_tuple,
    _validate_input,
)
from captum._utils.typing import (
    BaselineType,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from typing import Any, Callable, Tuple

from tint.utils import TensorDataset, _add_temporal_mask, default_collate

from attribution.gatemasknn import GateMaskNet


class GateMask(PerturbationAttribution):
    """
    Extremal masks.

    This method extends the work of Fong et al. and Crabbé et al. by
    allowing the perturbation function to be learnt. This is in addition
    to the learnt mask. For instance, this perturbation function can be
    learnt with a RNN while Crabbé et al. only consider fixed
    perturbations: Gaussian blur and fade to moving average.

    Args:
        forward_func (callable): The forward function of the model or any
            modification of it.

    References:
        #. `Learning Perturbations to Explain Time Series Predictions <https://arxiv.org/abs/2305.18840>`_
        #. `Understanding Deep Networks via Extremal Perturbations and Smooth Masks <https://arxiv.org/abs/1910.08485>`_

    Examples:
        >>> import torch as th
        >>> from attribution.gate_mask import GateMask
        >>> from tint.models import MLP
        <BLANKLINE>
        >>> inputs = th.rand(8, 7, 5)
        >>> data = th.rand(32, 7, 5)
        >>> mlp = MLP([5, 3, 1])
        <BLANKLINE>
        >>> explainer = GateMask(mlp)
        >>> attr = explainer.attribute(inputs)
    """

    def __init__(self, forward_func: Callable) -> None:
        super().__init__(forward_func=forward_func)

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        trainer: Trainer = None,
        mask_net: GateMaskNet = None,
        batch_size: int = 32,
        temporal_additional_forward_args: Tuple[bool] = None,
        return_temporal_attributions: bool = False,
        win_size: int = 5,
        sigma: float = 0.5,
    ) -> TensorOrTupleOfTensorsGeneric:
        """
        Compute attributions by fitting a gate mask network on ``inputs``.

        A ``GateMaskNet`` is trained (via a PyTorch-Lightning ``Trainer``)
        to perturb the input, and the learnt mask representation is
        returned as the attribution.

        Args:
            inputs: Input data to explain. Only a single input tensor is
                accepted (see assertion below).
            baselines: Baselines for the perturbation. Formatted and
                validated against ``inputs``.
            target: Output target passed to the training dataset.
            additional_forward_args: Extra arguments forwarded to the
                model alongside the data.
            trainer: Optional Lightning trainer; a default one with
                ``max_epochs=100`` is created when ``None``. A provided
                trainer is deep-copied so the caller's instance is not
                mutated by fitting.
            mask_net: Optional pre-built ``GateMaskNet``; created from
                ``self.forward_func`` when ``None``.
            batch_size: Batch size for both the mask net init and the
                training dataloader.
            temporal_additional_forward_args: Flags telling which
                additional args must be expanded temporally.
            return_temporal_attributions: If True, attributions are
                computed per time step and reshaped accordingly.
            win_size: Window size forwarded to the mask net init.
            sigma: Sigma forwarded to the mask net init.

        Returns:
            Attributions with the same tuple-ness as ``inputs``.
        """
        # Keeps track whether original input is a tuple or not before
        # converting it into a tuple.
        is_inputs_tuple = _is_tuple(inputs)
        inputs = _format_inputs(inputs)

        # Format and validate baselines
        baselines = _format_baseline(baselines, inputs)
        _validate_input(inputs, baselines)

        # Init trainer if not provided; deep-copy a user-supplied trainer
        # so repeated attribute() calls start from a fresh state.
        if trainer is None:
            trainer = Trainer(max_epochs=100)
        else:
            trainer = copy.deepcopy(trainer)

        # This method only accepts a single input tensor.
        assert (
            len(inputs) == 1
        ), "Multiple inputs are not accepted for this method"
        data = inputs[0]
        baseline = baselines[0]

        # If return temporal attr, we expand the input data
        # and multiply it with a lower triangular mask
        if return_temporal_attributions:
            data, additional_forward_args, _ = _add_temporal_mask(
                inputs=data,
                additional_forward_args=additional_forward_args,
                temporal_additional_forward_args=temporal_additional_forward_args,
            )

        # Init MaskNet if not provided
        if mask_net is None:
            mask_net = GateMaskNet(forward_func=self.forward_func)

        # Init model internals (mask parameters sized from the data shape)
        mask_net.net.init(
            input_size=data.shape,
            batch_size=batch_size,
            win_size=win_size,
            sigma=sigma,
            n_epochs=trainer.max_epochs,
        )

        # Prepare data: (input, input, baseline, target, extra args) tuples;
        # a None placeholder keeps the dataset arity fixed when there are
        # no additional forward args.
        dataloader = DataLoader(
            TensorDataset(
                *(data, data, baseline, target, *additional_forward_args)
                if additional_forward_args is not None
                else (data, data, baseline, target, None)
            ),
            batch_size=batch_size,
            collate_fn=default_collate,
        )

        # Fit model
        trainer.fit(mask_net, train_dataloaders=dataloader)

        # Set model to eval mode and cast it to the data's device
        mask_net.eval()
        mask_net.to(data.device)

        # Get attributions as mask representation
        attributions = mask_net.net.representation(data)

        # Reshape representation if temporal attributions
        if return_temporal_attributions:
            attributions = attributions.reshape(
                (-1, data.shape[1]) + data.shape[1:]
            )

        # Reshape as a tuple
        attributions = (attributions,)

        # Format attributions and return with the original tuple-ness
        return _format_output(is_inputs_tuple, attributions)