in captum/attr/_models/base.py [0:0]
def configure_interpretable_embedding_layer(model, embedding_layer_name="embedding"):
r"""
This method wraps model's embedding layer with an interpretable embedding
layer that allows us to access the embeddings through their indices.
Args:
model (torch.nn.Model): An instance of PyTorch model that contains embeddings.
embedding_layer_name (str, optional): The name of the embedding layer
in the `model` that we would like to make interpretable.
Returns:
interpretable_emb (tensor): An instance of `InterpretableEmbeddingBase`
embedding layer that wraps model's embedding layer that is being
accessed through `embedding_layer_name`.
Examples::
>>> # Let's assume that we have a DocumentClassifier model that
>>> # has a word embedding layer named 'embedding'.
>>> # To make that layer interpretable we need to execute the
>>> # following command:
>>> net = DocumentClassifier()
>>> interpretable_emb = configure_interpretable_embedding_layer(net,
>>> 'embedding')
>>> # then we can use interpretable embedding to convert our
>>> # word indices into embeddings.
>>> # Let's assume that we have the following word indices
>>> input_indices = torch.tensor([1, 0, 2])
>>> # we can access word embeddings for those indices with the command
>>> # line stated below.
>>> input_emb = interpretable_emb.indices_to_embeddings(input_indices)
>>> # Let's assume that we want to apply integrated gradients to
>>> # our model and that target attribution class is 3
>>> ig = IntegratedGradients(net)
>>> attribution = ig.attribute(input_emb, target=3)
>>> # after we finish the interpretation we need to remove
>>> # interpretable embedding layer with the following command:
>>> remove_interpretable_embedding_layer(net, interpretable_emb)
"""
    # Look up the embedding submodule referenced by `embedding_layer_name`.
    embedding_layer = _get_deep_layer_name(model, embedding_layer_name)
    assert (
        embedding_layer.__class__ is not InterpretableEmbeddingBase
    ), "InterpretableEmbeddingBase has already been configured for layer {}".format(
        embedding_layer_name
    )
    warnings.warn(
        "In order to make embedding layers more interpretable they will "
        "be replaced with an interpretable embedding layer which wraps the "
        "original embedding layer and takes word embedding vectors as inputs of "
        "the forward function. This allows us to generate baselines for word "
        "embeddings and compute attributions for each embedding dimension. "
        "The original embedding layer must be set back by calling the "
        "`remove_interpretable_embedding_layer` function after model "
        "interpretation is finished."
    )
    # Replace the original embedding layer with the interpretable wrapper.
    interpretable_emb = InterpretableEmbeddingBase(
        embedding_layer, embedding_layer_name
    )
    _set_deep_layer_value(model, embedding_layer_name, interpretable_emb)
    return interpretable_emb
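
For reference, a minimal end-to-end sketch of the configure / attribute / remove
cycle described in the docstring above. The `TextClassifier` model and its
dimensions are hypothetical stand-ins; `configure_interpretable_embedding_layer`,
`remove_interpretable_embedding_layer`, `indices_to_embeddings`, and
`IntegratedGradients` are the Captum APIs used above.

import torch
import torch.nn as nn

from captum.attr import (
    IntegratedGradients,
    configure_interpretable_embedding_layer,
    remove_interpretable_embedding_layer,
)


class TextClassifier(nn.Module):
    # Hypothetical toy model with an embedding layer named 'embedding'.
    def __init__(self, vocab_size=10, embed_dim=8, num_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.linear = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Once the interpretable wrapper is configured, `x` is already an
        # embedding tensor and the wrapped layer passes it through unchanged.
        return self.linear(self.embedding(x).mean(dim=1))


net = TextClassifier()
interpretable_emb = configure_interpretable_embedding_layer(net, "embedding")

# The wrapped layer now expects embedding vectors as inputs, so we convert
# word indices to embeddings explicitly.
input_indices = torch.tensor([[1, 0, 2]])
input_emb = interpretable_emb.indices_to_embeddings(input_indices)

ig = IntegratedGradients(net)
attributions = ig.attribute(input_emb, target=3)  # target class 3, as above

# Restore the original embedding layer once interpretation is finished.
remove_interpretable_embedding_layer(net, interpretable_emb)

Note that attributions are computed per embedding dimension here; summing over
the last dimension yields a single score per token.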