def get_max_num_attended_to_sequence_tokens()

in neuron_explainer/activations/attention_utils.py


def get_max_num_attended_to_sequence_tokens(num_sequence_tokens: int, num_activations: int) -> int:
    # Attended-to sequences are assumed to grow in length up to a maximum, then stay at that
    # length for the remainder of the sequence. The maximum attended-to sequence length is at
    # most the sequence length, but may be shorter.
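    # With a full lower-triangular attention pattern, token i attends to i tokens, so the
    # total number of (token, attended-to token) pairs is a triangular number.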
    num_sequence_token_pairs = num_sequence_tokens * (num_sequence_tokens + 1) // 2
    if num_activations == num_sequence_token_pairs:
        # The maximum attended-to sequence length equals the sequence length.
        return num_sequence_tokens
    else:
        # The maximum attended-to sequence length is less than the sequence length, so some
        # activations are missing relative to the full triangular pattern.
        assert num_activations < num_sequence_token_pairs
        num_missing_activations = num_sequence_token_pairs - num_activations
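        # The missing activations themselves form a triangular number whose index is the
        # difference between the sequence length and the maximum attended-to length.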
        num_missing_sequence_tokens = _inverse_triangular_number(num_missing_activations)
        max_num_attended_to_sequence_tokens = num_sequence_tokens - num_missing_sequence_tokens
        assert max_num_attended_to_sequence_tokens > 0
        return max_num_attended_to_sequence_tokens
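
The helper _inverse_triangular_number is called above but not shown in this excerpt. A minimal sketch of what it plausibly does, assuming it exactly inverts k * (k + 1) // 2 (the implementation below is an assumption, not the repo's code):

import math


def _inverse_triangular_number(triangular_number: int) -> int:
    # Solve k * (k + 1) // 2 == triangular_number for k via the quadratic formula,
    # using an integer square root to avoid floating-point error.
    k = (math.isqrt(8 * triangular_number + 1) - 1) // 2
    # The caller is expected to pass an exact triangular number.
    assert k * (k + 1) // 2 == triangular_number
    return k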
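
A usage sketch with assumed example values: for num_sequence_tokens = 5 and a maximum attended-to length of 3, token i attends to min(i, 3) tokens, giving num_activations = 1 + 2 + 3 + 3 + 3 = 12.

assert get_max_num_attended_to_sequence_tokens(num_sequence_tokens=5, num_activations=12) == 3

# With the full lower-triangular pattern, all 5 * 6 // 2 == 15 pairs are present.
assert get_max_num_attended_to_sequence_tokens(num_sequence_tokens=5, num_activations=15) == 5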