in captum/concept/_core/tcav.py [0:0]
def generate_activation(self, layers: Union[str, List], concept: Concept) -> None:
    r"""
    Computes layer activations for the specified `concept` and
    the list of layer(s) `layers`, and persists them with `AV.save`
    (one saved entry per layer per batch yielded by the concept's
    data iterator).

    Args:
        layers (str, list[str]): A list of layer names or a layer name
                that is used to compute layer activations for the
                specific `concept`.
        concept (Concept): A single Concept object that provides access
                to concept examples using a data iterator.

    Raises:
        AssertionError: If `concept.data_iter` is None.
    """
    # Validate before resolving modules / building the attribution object so
    # a missing iterator fails fast. The message must be a single string:
    # a tuple here would be rendered as a tuple repr in the AssertionError.
    assert concept.data_iter is not None, (
        "Data iterator for concept id: {} must be specified".format(concept.id)
    )

    # Normalize a single layer name to a one-element list.
    layers = [layers] if isinstance(layers, str) else layers
    layer_modules = [_get_module_from_name(self.model, layer) for layer in layers]
    layer_act = LayerActivation(self.model, layer_modules)

    for i, examples in enumerate(concept.data_iter):
        # `__wrapped__` invokes the undecorated `attribute` function directly,
        # passing the instance explicitly. NOTE(review): presumably this is
        # to bypass a decorator on `attribute` (e.g. usage logging) — confirm.
        activations = layer_act.attribute.__wrapped__(  # type: ignore
            layer_act,
            examples,
            attribute_to_layer_input=self.attribute_to_layer_input,
        )
        # Assumes `activations` holds one entry per requested layer, in the
        # same order as `layers` (a list of modules was passed above).
        for activation, layer_name in zip(activations, layers):
            # Flatten all dimensions after the first into one; the first
            # dimension is presumably the batch dimension — TODO confirm.
            activation = torch.reshape(activation, (activation.shape[0], -1))
            # The enumerate index `i` (as a string) distinguishes the saved
            # activations of successive batches from the data iterator.
            AV.save(
                self.save_path,
                self.model_id,
                concept.identifier,
                layer_name,
                activation.detach(),
                str(i),
            )