def configure_interpretable_embedding_layer()

in captum/attr/_models/base.py [0:0]


def configure_interpretable_embedding_layer(model, embedding_layer_name="embedding"):
    r"""
    Replace one of the model's embedding layers with an interpretable wrapper.

    The layer found under `embedding_layer_name` is wrapped in an
    `InterpretableEmbeddingBase` instance, and that wrapper is stored back on
    `model` under the same name. The model is therefore modified in place and
    must be restored with `remove_interpretable_embedding_layer` once
    interpretation is finished. A warning describing this replacement is
    issued on every call (subject to the active warning filters).

    Args:

        model (torch.nn.Module): An instance of PyTorch model that contains
                    embeddings.
        embedding_layer_name (str, optional): The name of the embedding layer
                    in the `model` that we would like to make interpretable.
                    Default: "embedding"

    Returns:

        interpretable_emb (InterpretableEmbeddingBase): An instance of
                    `InterpretableEmbeddingBase` embedding layer that wraps
                    model's embedding layer that is being accessed through
                    `embedding_layer_name`.

    Raises:

        AssertionError: If the layer found under `embedding_layer_name` is
                    already an `InterpretableEmbeddingBase` (or a subclass of
                    it), i.e. the layer has already been configured.

    Examples::

                >>> # Let's assume that we have a DocumentClassifier model that
                >>> # has a word embedding layer named 'embedding'.
                >>> # To make that layer interpretable we need to execute the
                >>> # following command:
                >>> net = DocumentClassifier()
                >>> interpretable_emb = configure_interpretable_embedding_layer(net,
                ...    'embedding')
                >>> # then we can use interpretable embedding to convert our
                >>> # word indices into embeddings.
                >>> # Let's assume that we have the following word indices
                >>> input_indices = torch.tensor([1, 0, 2])
                >>> # we can access word embeddings for those indices with the command
                >>> # line stated below.
                >>> input_emb = interpretable_emb.indices_to_embeddings(input_indices)
                >>> # Let's assume that we want to apply integrated gradients to
                >>> # our model and that target attribution class is 3
                >>> ig = IntegratedGradients(net)
                >>> attribution = ig.attribute(input_emb, target=3)
                >>> # after we finish the interpretation we need to remove
                >>> # interpretable embedding layer with the following command:
                >>> remove_interpretable_embedding_layer(net, interpretable_emb)

    """
    embedding_layer = _get_deep_layer_name(model, embedding_layer_name)
    # isinstance (rather than an exact class comparison) also rejects
    # subclasses of the wrapper, so a layer is never wrapped twice.
    # AssertionError is kept as the failure type for backward compatibility.
    assert not isinstance(
        embedding_layer, InterpretableEmbeddingBase
    ), "InterpretableEmbeddingBase has already been configured for layer {}".format(
        embedding_layer_name
    )
    # stacklevel=2 attributes the warning to the caller's line instead of this
    # library line, so the user sees where the model was modified.
    warnings.warn(
        "In order to make embedding layers more interpretable they will "
        "be replaced with an interpretable embedding layer which wraps the "
        "original embedding layer and takes word embedding vectors as inputs of "
        "the forward function. This allows us to generate baselines for word "
        "embeddings and compute attributions for each embedding dimension. "
        "The original embedding layer must be set "
        "back by calling `remove_interpretable_embedding_layer` function "
        "after model interpretation is finished. ",
        stacklevel=2,
    )
    interpretable_emb = InterpretableEmbeddingBase(
        embedding_layer, embedding_layer_name
    )
    # Swap the wrapper in under the same name; the caller receives the wrapper
    # so it can later be passed to `remove_interpretable_embedding_layer`.
    _set_deep_layer_value(model, embedding_layer_name, interpretable_emb)
    return interpretable_emb