import copy
import torch as th

from captum.attr._utils.attribution import PerturbationAttribution
from captum.log import log_usage
from captum._utils.common import (
    _format_baseline,
    _format_inputs,
    _format_output,
    _is_tuple,
    _validate_input,
)
from captum._utils.typing import (
    BaselineType,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from typing import Any, Callable, Tuple

from tint.utils import TensorDataset, _add_temporal_mask, default_collate
from abstudy.gatemasknn_no_both import GateMaskNet


class GateMask(PerturbationAttribution):
    """
    Gate masks.

    Learns a perturbation mask for a single input by training a
    :class:`GateMaskNet` with a PyTorch Lightning ``Trainer``, then returns
    the learnt mask representation as the attribution.

    The overall procedure mirrors tint's ``ExtremalMask``, which extends the
    work of Fong et al. and Crabbé et al. by learning the perturbation
    function in addition to the mask. The exact mask / perturbation model is
    whatever ``GateMaskNet`` implements (defined outside this file).

    Args:
        forward_func (callable): The forward function of the model or any
            modification of it.

    References:
        #. `Learning Perturbations to Explain Time Series Predictions <https://arxiv.org/abs/2305.18840>`_
        #. `Understanding Deep Networks via Extremal Perturbations and Smooth Masks <https://arxiv.org/abs/1910.08485>`_

    Examples:
        >>> import torch as th
        >>> from tint.models import MLP
        <BLANKLINE>
        >>> inputs = th.rand(8, 7, 5)
        >>> mlp = MLP([5, 3, 1])
        <BLANKLINE>
        >>> explainer = GateMask(mlp)
        >>> attr = explainer.attribute(inputs)
    """

    def __init__(self, forward_func: Callable) -> None:
        super().__init__(forward_func=forward_func)

    @log_usage()
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        trainer: Trainer = None,
        mask_net: GateMaskNet = None,
        batch_size: int = 32,
        temporal_additional_forward_args: Tuple[bool, ...] = None,
        return_temporal_attributions: bool = False,
        win_size: int = 5,
        sigma: float = 0.5,
    ) -> TensorOrTupleOfTensorsGeneric:
        """
        Train a mask network on ``inputs`` and return its mask as attribution.

        Args:
            inputs: A tensor, or a tuple containing exactly one tensor, to
                explain. Multiple inputs are rejected.
            baselines: Baselines, formatted with captum's
                ``_format_baseline`` and passed to the mask network as a
                dataset field.
            target: Target passed to the mask network as a dataset field.
            additional_forward_args: Extra arguments appended to each dataset
                sample. A tuple or list is unpacked into separate fields; any
                other non-``None`` value is treated as a single field.
            trainer: Lightning ``Trainer`` used to fit the mask network. A
                deep copy is used so the caller's trainer is not modified.
                Defaults to ``Trainer(max_epochs=100)``.
            mask_net: Mask network to train. Defaults to a new
                ``GateMaskNet`` wrapping ``forward_func``. A provided network
                is initialised and trained in place.
            batch_size: Batch size of the training ``DataLoader``; also
                forwarded to ``mask_net.net.init``.
            temporal_additional_forward_args: Forwarded to
                ``_add_temporal_mask`` when temporal attributions are
                requested.
            return_temporal_attributions: If True, the input is expanded with
                ``_add_temporal_mask`` before training, and the resulting
                mask is reshaped to ``(-1, data.shape[1]) + data.shape[1:]``
                of the expanded data.
            win_size: Forwarded to ``mask_net.net.init``; its meaning is
                defined by ``GateMaskNet`` (presumably a smoothing window
                size — confirm there).
            sigma: Forwarded to ``mask_net.net.init``; its meaning is defined
                by ``GateMaskNet``.

        Returns:
            The attributions, as a tensor or a 1-tuple matching the format of
            ``inputs``.
        """
        # Remember whether the caller passed a tuple so the output can be
        # returned in the same format.
        is_inputs_tuple = _is_tuple(inputs)
        inputs = _format_inputs(inputs)

        # Format and validate baselines
        baselines = _format_baseline(baselines, inputs)
        _validate_input(inputs, baselines)

        # Use a fresh trainer, or a copy so the caller's trainer state
        # is not altered by fitting.
        if trainer is None:
            trainer = Trainer(max_epochs=100)
        else:
            trainer = copy.deepcopy(trainer)

        # Only a single input tensor is supported by this method.
        assert (
            len(inputs) == 1
        ), "Multiple inputs are not accepted for this method"
        data = inputs[0]
        baseline = baselines[0]

        # If return temporal attr, we expand the input data
        # and multiply it with a lower triangular mask
        if return_temporal_attributions:
            data, additional_forward_args, _ = _add_temporal_mask(
                inputs=data,
                additional_forward_args=additional_forward_args,
                temporal_additional_forward_args=temporal_additional_forward_args,
            )

        # A single non-sequence argument (e.g. one tensor) must become one
        # dataset field; star-unpacking it directly would split a tensor
        # along dim 0 (or fail for scalars).
        if additional_forward_args is not None and not isinstance(
            additional_forward_args, (tuple, list)
        ):
            additional_forward_args = (additional_forward_args,)

        # Init MaskNet if not provided
        if mask_net is None:
            mask_net = GateMaskNet(forward_func=self.forward_func)

        # Init model
        mask_net.net.init(
            input_size=data.shape,
            batch_size=batch_size,
            win_size=win_size,
            sigma=sigma,
            n_epochs=trainer.max_epochs,
        )

        # Prepare data: each sample is (x, x, baseline, target, *extra_args)
        dataloader = DataLoader(
            TensorDataset(
                *(data, data, baseline, target, *additional_forward_args)
                if additional_forward_args is not None
                else (data, data, baseline, target, None)
            ),
            batch_size=batch_size,
            collate_fn=default_collate,
        )

        # Fit model
        trainer.fit(mask_net, train_dataloaders=dataloader)

        # Set model to eval mode and cast it to device
        mask_net.eval()
        mask_net.to(data.device)

        # Get attributions as mask representation
        attributions = mask_net.net.representation(data)

        # Reshape representation if temporal attributions
        if return_temporal_attributions:
            attributions = attributions.reshape(
                (-1, data.shape[1]) + data.shape[1:]
            )

        # Reshape as a tuple
        attributions = (attributions,)

        # Format attributions and return
        return _format_output(is_inputs_tuple, attributions)
