in neuron_explainer/activations/derived_scalars/attention.py
def make_reshape_fn(dst: DerivedScalarType) -> Callable:
    """Create a reshape function to apply to the output tensors."""
    output_dim = dst.shape_spec_per_token_sequence
    error_msg = f"Unexpected output_dim: {output_dim}. Please add a reshape function."
    if len(output_dim) == 2:
        # Regular activations are already 2-d and don't need to be reshaped.
        assert output_dim[0] == Dimension.SEQUENCE_TOKENS
        assert output_dim[1].is_model_intrinsic
        reshape_fn = lambda x: x
    elif len(output_dim) == 3:
        assert output_dim[0] == Dimension.SEQUENCE_TOKENS
        assert output_dim[2].is_model_intrinsic
        if output_dim[1] == Dimension.ATTENDED_TO_SEQUENCE_TOKENS:
            # E.g. attention activations indexed by both the current token and the token attended to.
            # Here we move the two token-indexing dimensions to the end, extract the lower-triangle
            # entries, flatten them into a single token-pair dimension, and move that dimension to
            # the front.
            reshape_fn = lambda x: flatten_lower_triangle(x.permute(2, 0, 1)).permute(1, 0)
        elif output_dim[1] == Dimension.ATTN_HEADS:
            # E.g. attention activations that are split by attention head.
            # Here we merge the two model dimensions into one.
            reshape_fn = lambda x: x.reshape(x.shape[0], -1)
        else:
            raise NotImplementedError(error_msg)
    elif len(output_dim) == 4:
        assert output_dim[0] == Dimension.SEQUENCE_TOKENS
        assert output_dim[3].is_model_intrinsic
        if (
            output_dim[1] == Dimension.ATTENDED_TO_SEQUENCE_TOKENS
            and output_dim[2] == Dimension.ATTN_HEADS
        ):
            # Here we move the two token-indexing dimensions to the end, extract the lower-triangle
            # entries, flatten them into a single token-pair dimension, and move that dimension to
            # the front. Then we merge the flattened token-pair dimension with the attention heads
            # dimension.
            reshape_fn = (
                lambda x: flatten_lower_triangle(x.permute(2, 3, 0, 1))
                .permute(2, 0, 1)
                .reshape(-1, x.shape[3])
            )
        else:
            raise NotImplementedError(error_msg)
    else:
        raise NotImplementedError(error_msg)
    return reshape_fn
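

# --- Hedged sketch (not part of the library): a toy walkthrough of the shape bookkeeping above.
# It assumes flatten_lower_triangle keeps the lower-triangle (current token >= attended-to token)
# entries of the last two dimensions and flattens them into a single token-pair dimension; the
# stand-in helper, the function names, and the example sizes below are illustrative only.
import torch


def _toy_flatten_lower_triangle(x: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for flatten_lower_triangle: keep entries of the last two
    # (square) dimensions where row index >= column index, flattened into one dimension.
    n = x.shape[-1]
    rows, cols = torch.tril_indices(n, n)
    return x[..., rows, cols]


def _demo_reshape_shapes() -> None:
    n_tokens, n_heads, d_head = 4, 2, 8
    n_pairs = n_tokens * (n_tokens + 1) // 2  # 10 lower-triangle (token, attended-to) pairs

    # 2-d case: (tokens, model-intrinsic dim) passes through unchanged.

    # 3-d case, ATTENDED_TO_SEQUENCE_TOKENS: (tokens, attended-to, dim) -> (token pairs, dim).
    x3 = torch.randn(n_tokens, n_tokens, d_head)
    y3 = _toy_flatten_lower_triangle(x3.permute(2, 0, 1)).permute(1, 0)
    assert y3.shape == (n_pairs, d_head)

    # 3-d case, ATTN_HEADS: (tokens, heads, per-head dim) -> (tokens, heads * per-head dim).
    x3h = torch.randn(n_tokens, n_heads, d_head)
    assert x3h.reshape(x3h.shape[0], -1).shape == (n_tokens, n_heads * d_head)

    # 4-d case: (tokens, attended-to, heads, per-head dim) -> (token pairs * heads, per-head dim).
    x4 = torch.randn(n_tokens, n_tokens, n_heads, d_head)
    y4 = (
        _toy_flatten_lower_triangle(x4.permute(2, 3, 0, 1))
        .permute(2, 0, 1)
        .reshape(-1, x4.shape[3])
    )
    assert y4.shape == (n_pairs * n_heads, d_head)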