def make_reshape_fn()

in neuron_explainer/activations/derived_scalars/attention.py


def make_reshape_fn(dst: DerivedScalarType) -> Callable:
    """Create a reshape function to apply to the output tensors."""

    output_dim = dst.shape_spec_per_token_sequence
    error_msg = f"Unexpected output_dim: {output_dim}. Please add a reshape function."

    if len(output_dim) == 2:
        # Regular activations are already 2d and don't need to be reshaped.
        assert output_dim[0] == Dimension.SEQUENCE_TOKENS
        assert output_dim[1].is_model_intrinsic
        reshape_fn = lambda x: x

    elif len(output_dim) == 3:
        assert output_dim[0] == Dimension.SEQUENCE_TOKENS
        assert output_dim[2].is_model_intrinsic
        if output_dim[1] == Dimension.ATTENDED_TO_SEQUENCE_TOKENS:
            # E.g. attention activations that are indexed by both the current token and the token attended to.
            # Here we move the two token-indexing dimensions to the end, extract the lower-triangle
            # entries, flatten them into a single dimension, and move that dimension to the front.
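            # Shape sketch, assuming flatten_lower_triangle flattens the trailing two token
            # dimensions into a single dimension of (current token, attended-to token) pairs:
            # (n_tokens, n_tokens, model_dim) -> (n_token_pairs, model_dim).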
            reshape_fn = lambda x: flatten_lower_triangle(x.permute(2, 0, 1)).permute(1, 0)
        elif output_dim[1] == Dimension.ATTN_HEADS:
            # E.g. attention activations that are split by attention heads.
            # Here, we merge the two model dimensions into one.
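            # Shape sketch: (n_tokens, n_heads, model_dim) -> (n_tokens, n_heads * model_dim).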
            reshape_fn = lambda x: x.reshape(x.shape[0], -1)
        else:
            raise NotImplementedError(error_msg)

    elif len(output_dim) == 4:
        assert output_dim[0] == Dimension.SEQUENCE_TOKENS
        assert output_dim[3].is_model_intrinsic
        if (
            output_dim[1] == Dimension.ATTENDED_TO_SEQUENCE_TOKENS
            and output_dim[2] == Dimension.ATTN_HEADS
        ):
            # Here we move the two token-indexing dimensions to the end, extract the lower-triangle
            # entries, flatten them into a single dimension, and move that dimension to the front.
            # Then we merge the flattened token-pair dimension with the attention heads dimension.
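            # Shape sketch (same assumption about flatten_lower_triangle as above):
            # (n_tokens, n_tokens, n_heads, model_dim) -> (n_token_pairs * n_heads, model_dim).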
            reshape_fn = (
                lambda x: flatten_lower_triangle(x.permute(2, 3, 0, 1))
                .permute(2, 0, 1)
                .reshape(-1, x.shape[3])
            )
        else:
            raise NotImplementedError(error_msg)

    else:
        raise NotImplementedError(error_msg)
    return reshape_fn
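
For intuition, here is a minimal, self-contained sketch of the 3d attention case. The tensor sizes and the flatten_lower_triangle_sketch helper are illustrative stand-ins (assuming the real flatten_lower_triangle keeps the diagonal and flattens the lower triangle of the last two dimensions), not the repository's implementation.

import torch


def flatten_lower_triangle_sketch(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: flatten the lower triangle (attended-to position
    # <= current position) of the last two dimensions into one trailing dimension.
    rows, cols = torch.tril_indices(x.shape[-2], x.shape[-1])
    return x[..., rows, cols]


n_tokens, model_dim = 4, 8
acts = torch.randn(n_tokens, n_tokens, model_dim)  # (current, attended-to, model_dim)

# Mirrors the 3d ATTENDED_TO_SEQUENCE_TOKENS branch above.
reshaped = flatten_lower_triangle_sketch(acts.permute(2, 0, 1)).permute(1, 0)
assert reshaped.shape == (n_tokens * (n_tokens + 1) // 2, model_dim)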