# neuron_explainer/activations/derived_scalars/tokens.py
"""
This file contains utilities and a class for converting token-space vectors to and from a pydantic
base class summarizing them in terms of the extremal entries in the vector and their associated
tokens. The class can be used in InteractiveModel responses.
"""
import math
import torch
from pydantic import validator
from neuron_explainer.activations.derived_scalars.least_common_tokens import (
LEAST_COMMON_GPT2_TOKEN_STRS,
)
from neuron_explainer.models.model_context import ModelContext
from neuron_explainer.pydantic import CamelCaseBaseModel, immutable


@immutable
class TokenAndRawScalar(CamelCaseBaseModel):
    token: str
    scalar: float

@validator("scalar")
def check_scalar(cls, scalar: float) -> float:
assert math.isfinite(scalar), "Scalar value must be a finite number"
        return scalar


@immutable
class TokenAndScalar(TokenAndRawScalar):
    normalized_scalar: float

@validator("normalized_scalar")
def check_normalized_scalar(cls, normalized_scalar: float) -> float:
assert math.isfinite(normalized_scalar), "Normalized scalar value must be a finite number"
return normalized_scalar
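

# Illustrative usage (a sketch, not part of the original file): the validators reject
# non-finite scalars at construction time, so a NaN or inf weight fails fast.
#
#     TokenAndScalar(token=" cat", scalar=2.5, normalized_scalar=1.0)  # ok
#     TokenAndScalar(token=" cat", scalar=float("nan"), normalized_scalar=1.0)  # raises
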
@immutable
class TopTokens(CamelCaseBaseModel):
"""
Contains two lists of tokens and associated scalars: one for the highest-scoring tokens and one
for the lowest-scoring tokens, according to some way of scoring tokens. For example, this could
be used to represent the top upvoted and downvoted "logit lens" tokens. An instance of this
class is scoped to a single node. The set of tokens eligible for scoring is typically just the
model's entire vocabulary. Each list is sorted from largest to smallest absolute value for the
associated scalar.
"""
top: list[TokenAndScalar]
bottom: list[TokenAndScalar]
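

# Illustrative shape of a TopTokens instance (hypothetical values, not from the original
# file) for a single node with top_t_tokens=2; each list is sorted by descending absolute
# scalar, and normalized_scalar is scalar divided by the per-node max(abs(scalar)):
#
#     TopTokens(
#         top=[
#             TokenAndScalar(token=" dog", scalar=4.0, normalized_scalar=1.0),
#             TokenAndScalar(token=" cat", scalar=3.0, normalized_scalar=0.75),
#         ],
#         bottom=[
#             TokenAndScalar(token=" not", scalar=-2.0, normalized_scalar=-0.5),
#             TokenAndScalar(token=" never", scalar=-1.0, normalized_scalar=-0.25),
#         ],
#     )
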
def package_top_t_tokens(
model_context: ModelContext,
top_t_upvoted_token_ints_tensor: torch.Tensor,
top_t_upvoted_token_weights_tensor: torch.Tensor,
norm_top_t_upvoted_token_weights_tensor: torch.Tensor,
) -> list[list[TokenAndScalar]]:
"""
Convert tensors of top t upvoted token ints, weights, and normalized weights into a list of
lists of TokenAndScalar, one list per node.
"""
    n_nodes, n_tokens = top_t_upvoted_token_ints_tensor.shape
    top_t_upvoted_token_strings = [
        model_context.decode_token_list(top_t_upvoted_token_ints_tensor[i].tolist())
        for i in range(n_nodes)
    ]
top_t_upvoted_token_weights = top_t_upvoted_token_weights_tensor.tolist()
norm_top_t_upvoted_token_weights = norm_top_t_upvoted_token_weights_tensor.tolist()
    token_and_weight_data_for_all_nodes = []
    # For each node (row of the tensors), zip the token strings, weights, and normalized
    # weights into a list of TokenAndScalar.
    for node_index in range(n_nodes):
        token_and_weight_data_for_this_node = []
        for token_index in range(n_tokens):
token_and_weight_data_for_this_node.append(
TokenAndScalar(
token=top_t_upvoted_token_strings[node_index][token_index],
scalar=top_t_upvoted_token_weights[node_index][token_index],
normalized_scalar=norm_top_t_upvoted_token_weights[node_index][token_index],
)
)
token_and_weight_data_for_all_nodes.append(token_and_weight_data_for_this_node)
return token_and_weight_data_for_all_nodes
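

# Shape sketch (hypothetical values, not from the original file): with n_nodes=2 and
# top_t_tokens=3, all three input tensors are [2, 3] and the output is a length-2 list of
# length-3 lists of TokenAndScalar; result[0][1] is the second-ranked token for node 0.
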
def get_top_t_tokens_maybe_excluding_least_common(
token_writes_tensor: torch.Tensor,
top_t_tokens: int,
largest: bool,
least_common_tokens_as_ints: list[int] | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return the top t tokens and their weights, optionally excluding the least common token ints,
passed as an argument. Sorted by largest or smallest absolute value, as specified by the
'largest' argument.
"""
if least_common_tokens_as_ints is not None:
token_writes_tensor[:, least_common_tokens_as_ints] = (
float("-inf") if largest else float("inf")
)
(
top_t_upvoted_token_weights_tensor,
top_t_upvoted_token_ints_tensor,
) = token_writes_tensor.topk(k=top_t_tokens, largest=largest)
assert torch.isfinite(
top_t_upvoted_token_weights_tensor
).all(), "Top token weights should only contain finite values"
return top_t_upvoted_token_weights_tensor, top_t_upvoted_token_ints_tensor
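

# Worked sketch (hypothetical values, not from the original file):
#
#     token_writes = torch.tensor([[0.5, -1.0, 2.0, 0.1]])
#     get_top_t_tokens_maybe_excluding_least_common(
#         token_writes, top_t_tokens=2, largest=True, least_common_tokens_as_ints=[2]
#     )
#
# masks column 2 to -inf, so topk returns weights [[0.5, 0.1]] and token ints [[0, 3]].
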
def get_most_upvoted_and_downvoted_tokens_for_nodes(
model_context: ModelContext,
token_writes_tensor: torch.Tensor,
top_t_tokens: int,
flip_upvoted_and_downvoted: bool = False,
) -> list[TopTokens]:
"""
Convert a 2D token_writes_tensor to the most positive (upvoted) and most negative (downvoted) vocab tokens per row,
with weights corresponding to how upvoted or downvoted each token is. Return a list (indexed by row index) of
TopTokens, each of which contains a list of TokenAndScalar for the most upvoted tokens and
a list of TokenAndScalar for the most downvoted tokens.
Note that the scalars in TokenAndScalar are referred to as 'weights', despite being held in an object called
TokenAndScalar. The weights returned here include normalized versions (normalized to max(abs(weight))).
"""
    if model_context.get_encoding().name == "gpt2":
        # For GPT-2, exclude the least common tokens (matched by token string) from the
        # top t tokens displayed.
least_common_tokens_as_ints = model_context.encode_token_str_list(
LEAST_COMMON_GPT2_TOKEN_STRS
)
else:
least_common_tokens_as_ints = None
(
top_t_upvoted_token_weights_tensor,
top_t_upvoted_token_ints_tensor,
) = get_top_t_tokens_maybe_excluding_least_common(
token_writes_tensor,
top_t_tokens,
largest=True,
least_common_tokens_as_ints=least_common_tokens_as_ints,
)
(
top_t_downvoted_token_weights_tensor,
top_t_downvoted_token_ints_tensor,
) = get_top_t_tokens_maybe_excluding_least_common(
token_writes_tensor,
top_t_tokens,
largest=False,
least_common_tokens_as_ints=least_common_tokens_as_ints,
)
    # topk output is sorted, so column 0 of each tensor holds the most extreme weight in
    # each row; the per-row normalizer is the larger magnitude of the two.
    max_abs_token_writes_tensor = torch.max(
        top_t_upvoted_token_weights_tensor[:, 0:1].abs(),
        top_t_downvoted_token_weights_tensor[:, 0:1].abs(),
    )

def safe_divide(
numerator_tensor: torch.Tensor, denominator_tensor: torch.Tensor
) -> torch.Tensor:
assert torch.isfinite(
numerator_tensor
).all(), "Numerator tensor should only contain finite values"
assert torch.isfinite(
denominator_tensor
).all(), "Denominator tensor should only contain finite values"
return torch.where(
denominator_tensor == 0,
torch.zeros_like(numerator_tensor),
numerator_tensor / denominator_tensor,
)
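
    # Worked sketch (hypothetical values): safe_divide(torch.tensor([2.0, 3.0]),
    # torch.tensor([4.0, 0.0])) returns tensor([0.5000, 0.0000]) rather than propagating
    # inf from the zero denominator.
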
norm_top_t_upvoted_token_weights_tensor = safe_divide(
top_t_upvoted_token_weights_tensor, max_abs_token_writes_tensor
)
norm_top_t_downvoted_token_weights_tensor = safe_divide(
top_t_downvoted_token_weights_tensor, max_abs_token_writes_tensor
)
normalized_most_upvoted_tokens_for_all_nodes = package_top_t_tokens(
model_context,
top_t_upvoted_token_ints_tensor,
top_t_upvoted_token_weights_tensor,
norm_top_t_upvoted_token_weights_tensor,
)
normalized_most_downvoted_tokens_for_all_nodes = package_top_t_tokens(
model_context,
top_t_downvoted_token_ints_tensor,
top_t_downvoted_token_weights_tensor,
norm_top_t_downvoted_token_weights_tensor,
)
    # Zip the per-node upvoted and downvoted lists into a list of TopTokens.
    top_tokens_list = []
    for node_index in range(len(normalized_most_upvoted_tokens_for_all_nodes)):
        upvoted_tokens = normalized_most_upvoted_tokens_for_all_nodes[node_index]
        downvoted_tokens = normalized_most_downvoted_tokens_for_all_nodes[node_index]
        if flip_upvoted_and_downvoted:
            upvoted_tokens, downvoted_tokens = downvoted_tokens, upvoted_tokens
        top_tokens_list.append(TopTokens(top=upvoted_tokens, bottom=downvoted_tokens))
assert len(top_tokens_list) == token_writes_tensor.shape[0]
return top_tokens_list
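

# Example usage (a hedged sketch, not part of the original file; how `model_context` and
# the token-space write directions are obtained depends on the surrounding codebase and is
# assumed here):
#
#     token_writes = direction_vectors @ unembedding_matrix  # [n_nodes, n_vocab]
#     top_tokens_per_node = get_most_upvoted_and_downvoted_tokens_for_nodes(
#         model_context=model_context,
#         token_writes_tensor=token_writes,
#         top_t_tokens=10,
#     )
#     print(top_tokens_per_node[0].top[0].token)  # most upvoted token for node 0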