neuron_explainer/activations/attention_utils.py

""" Contains math utilities for converting from flattened representations of attention activations (which are a scalar per token pair) to nested lists. The inner lists are attention activations related to attention from the same token (to different tokens). Tested in ./test_attention_utils.py. """ import math import numpy as np def _inverse_triangular_number(n: int) -> int: # the m'th triangular number t_m satisfies t_m = m(m+1)/2 # this function asserts that n is a triangular number, and returns the unique m such that t_m = n # this is used to infer the number of sequence tokens from the number of activations assert n >= 0 m: int = ( math.floor(math.sqrt(1 + 8 * n)) - 1 ) // 2 # from quadratic formula applied to m(m+1)/2 = n assert m * (m + 1) // 2 == n return m def get_max_num_attended_to_sequence_tokens(num_sequence_tokens: int, num_activations: int) -> int: # Attended to sequences are assumed to increase in length up to a maximum length, and then stay at that # length for the remainder of the sequence. The maximum attended to sequence length is at most equal to the sequence length, # but is permitted to be less num_sequence_token_pairs = num_sequence_tokens * (num_sequence_tokens + 1) // 2 if num_activations == num_sequence_token_pairs: # the maximum attended to sequence length is equal to the sequence length return num_sequence_tokens else: # the maximum attended to sequence length is less than the sequence length, and assert num_activations < num_sequence_token_pairs num_missing_activations = num_sequence_token_pairs - num_activations num_missing_sequence_tokens = _inverse_triangular_number(num_missing_activations) max_num_attended_to_sequence_tokens = num_sequence_tokens - num_missing_sequence_tokens assert max_num_attended_to_sequence_tokens > 0 return max_num_attended_to_sequence_tokens def get_attended_to_sequence_length_per_sequence_token( num_sequence_tokens: int, max_num_attended_to_sequence_tokens: int ) -> list[int]: # given a num_sequence_tokens and a max_num_attended_to_sequence_tokens, return a list of length num_sequence_tokens # where the ith element is the length of the attended to sequence for the ith sequence token. 


def get_attended_to_sequence_length_per_sequence_token(
    num_sequence_tokens: int, max_num_attended_to_sequence_tokens: int
) -> list[int]:
    # Given num_sequence_tokens and max_num_attended_to_sequence_tokens, return a list of length
    # num_sequence_tokens where the ith element is the length of the attended-to sequence for the
    # ith sequence token.
    # The length of the attended-to sequence starts at 1, increases by 1 with each token up to
    # max_num_attended_to_sequence_tokens, and then stays at max_num_attended_to_sequence_tokens
    # for the remainder of the sequence.
    assert num_sequence_tokens >= max_num_attended_to_sequence_tokens
    attended_to_sequence_lengths = list(range(1, max_num_attended_to_sequence_tokens + 1))
    if num_sequence_tokens > max_num_attended_to_sequence_tokens:
        attended_to_sequence_lengths.extend(
            [
                max_num_attended_to_sequence_tokens
                for _ in range(num_sequence_tokens - max_num_attended_to_sequence_tokens)
            ]
        )
    return attended_to_sequence_lengths


def get_attended_to_sequence_lengths(num_sequence_tokens: int, num_activations: int) -> list[int]:
    max_num_attended_to_sequence_tokens = get_max_num_attended_to_sequence_tokens(
        num_sequence_tokens, num_activations
    )
    return get_attended_to_sequence_length_per_sequence_token(
        num_sequence_tokens, max_num_attended_to_sequence_tokens
    )


def _convert_flattened_index_to_unflattened_index_assuming_square_matrix(
    flat_index: int,
) -> tuple[int, int]:
    # This converts a flattened index into a (sequence token, attended-to token) index pair,
    # assuming the activations cover a full lower-triangular (square) attention matrix.
    # The row index n is recovered by inverting the triangular-number formula n(n+1)/2 via the
    # quadratic formula; the column index m is the remaining offset within that row.
    n = math.floor((-1 + math.sqrt(1 + 8 * flat_index)) / 2)
    m = flat_index - n * (n + 1) // 2
    return n, m


def convert_flattened_index_to_unflattened_index(
    flattened_index: int,
    num_sequence_tokens: int | None = None,
    num_activations: int | None = None,
) -> tuple[int, int]:
    # Given a flattened index, return the unflattened index.
    # If the attention matrix is square (most common), then the flattened_index uniquely determines
    # the index within the square matrix.
    # If the attention matrix has more rows (sequence tokens) than columns (attended-to sequence
    # tokens), then num_sequence_tokens and num_activations are required to determine the index
    # within the matrix.
    # Specify both num_sequence_tokens and num_activations, or neither.
    assert not (num_sequence_tokens is None) ^ (num_activations is None)
    if (
        num_sequence_tokens is None
        or num_activations == num_sequence_tokens * (num_sequence_tokens + 1) // 2
    ):
        assume_square_matrix = True
    else:
        assume_square_matrix = False
    if assume_square_matrix:
        return _convert_flattened_index_to_unflattened_index_assuming_square_matrix(flattened_index)
    else:
        assert num_sequence_tokens is not None
        assert num_activations is not None
        assert flattened_index < num_activations
        sequence_lengths = get_attended_to_sequence_lengths(num_sequence_tokens, num_activations)
        sequence_lengths_cumsum = np.cumsum([0] + sequence_lengths)
        sequence_index = int(
            np.searchsorted(sequence_lengths_cumsum, flattened_index, side="right") - 1
        )
        assert sequence_lengths_cumsum[sequence_index] <= flattened_index, (
            sequence_lengths_cumsum[sequence_index],
            flattened_index,
        )
        assert sequence_lengths_cumsum[sequence_index + 1] >= flattened_index, (
            sequence_lengths_cumsum[sequence_index + 1],
            flattened_index,
        )
        index_within_sequence = flattened_index - sequence_lengths_cumsum[sequence_index]
        return sequence_index, index_within_sequence
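

# The block below is an added illustrative sketch (not part of the original module) showing how the
# helpers above fit together on small, arbitrarily chosen inputs.
if __name__ == "__main__":
    # Square case: 10 activations for 4 sequence tokens is the full lower triangle (4 * 5 / 2), so
    # each token attends to itself and every earlier token.
    assert get_attended_to_sequence_lengths(num_sequence_tokens=4, num_activations=10) == [1, 2, 3, 4]
    # Flattened index 4 falls in the row for sequence token 2, at attended-to position 1.
    assert convert_flattened_index_to_unflattened_index(4) == (2, 1)

    # Truncated case: 7 activations for 4 sequence tokens implies a maximum attended-to length of 2,
    # so the per-token attended-to lengths are [1, 2, 2, 2].
    assert get_attended_to_sequence_lengths(num_sequence_tokens=4, num_activations=7) == [1, 2, 2, 2]
    # With the truncated layout, flattened index 4 still maps to sequence token 2, position 1,
    # because the cumulative lengths are [0, 1, 3, 5, 7].
    assert convert_flattened_index_to_unflattened_index(
        4, num_sequence_tokens=4, num_activations=7
    ) == (2, 1)
    print("attention_utils examples passed")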