# neuron_explainer/activation_server/requests_and_responses.py
"""
Request and response definitions. This file shouldn't contain any functions, other than those
defined on the dataclasses.
Requests to InteractiveModel have two parts: an InferenceRequestSpec, specifying how to
run inference to obtain activations, and a ProcessingRequestSpec, specifying how to process those
activations to obtain derived scalars. An InferenceSubRequest contains information for a single
inference step (forward and optionally also backward pass), and one or more ProcessingRequestSpecs
to process activations from the same inference step. A BatchedRequest contains one or more
InferenceSubRequests, whose inference steps are run in parallel, and whose processing steps are
performed sequentially.
InferenceRequests are analogous to single InferenceSubRequests, and are processed stand-alone rather
than in a batch.
TdbRequests compactly specify the information in InferenceRequestSpec and ProcessingRequestSpec,
with only the degrees of freedom permitted by the TDB UI. TdbRequests are converted to
InferenceSubRequests and ProcessingRequestSpecs in tdb_conversions.py. BatchedTdbRequests are
analogous to BatchedRequests.
"""
import math
from enum import Enum
from typing import Any, Literal, Union
import torch
from pydantic import root_validator
from neuron_explainer.activation_server.load_neurons import NodeIdAndDatasets
from neuron_explainer.activation_server.read_routes import TokenAndAttentionScalars
from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.activations.derived_scalars.indexing import (
AblationSpec,
MirroredActivationIndex,
MirroredNodeIndex,
MirroredTraceConfig,
NodeAblation,
NodeToTrace,
)
from neuron_explainer.activations.derived_scalars.tokens import TopTokens
from neuron_explainer.models.model_component_registry import Dimension, LayerIndex, PassType
from neuron_explainer.pydantic import CamelCaseBaseModel, immutable
########## Types used by multiple requests and/or responses ##########
# NOTE: other than TDB_REQUEST_SPEC, these all contain params for processing activations
# only, and not for specifying how to run inference to obtain those activations
class SpecType(Enum):
ACTIVATIONS_REQUEST_SPEC = "activations_request_spec"
DERIVED_SCALARS_REQUEST_SPEC = "derived_scalars_request_spec"
DERIVED_ATTENTION_SCALARS_REQUEST_SPEC = "derived_attention_scalars_request_spec"
MULTIPLE_TOP_K_DERIVED_SCALARS_REQUEST_SPEC = "multiple_top_k_derived_scalars_request_spec"
SCORED_TOKENS_REQUEST_SPEC = "scored_tokens_request_spec"
TDB_REQUEST_SPEC = "tdb_request_spec"
TOKEN_PAIR_ATTRIBUTION_REQUEST_SPEC = "token_pair_attribution_request_spec"
class ProcessingResponseDataType(Enum):
DERIVED_SCALARS_RESPONSE_DATA = "derived_scalars_response_data"
DERIVED_ATTENTION_SCALARS_RESPONSE_DATA = "derived_attention_scalars_response_data"
MULTIPLE_TOP_K_DERIVED_SCALARS_RESPONSE_DATA = "multiple_top_k_derived_scalars_response_data"
SCORED_TOKENS_RESPONSE_DATA = "scored_tokens_response_data"
TOKEN_PAIR_ATTRIBUTION_RESPONSE_DATA = "token_pair_attribution_response_data"
class LossFnName(str, Enum):
LOGIT_DIFF = "logit_diff"
LOGIT_MINUS_MEAN = "logit_minus_mean"
PROBS = "probs"
ZERO = "zero"
class LossFnConfig(CamelCaseBaseModel):
name: LossFnName
target_tokens: list[str] | None = None
distractor_tokens: list[str] | None = None
@immutable
class InferenceRequestSpec(CamelCaseBaseModel):
"""The minimum specification for performing a forward and/or backward pass on a model, with hooks at some set of layers."""
prompt: str
ablation_specs: list[AblationSpec] | None = None
# note that loss_fn_config and trace_config are mutually exclusive
loss_fn_config: LossFnConfig | None = None
# used for performing a backward pass from an internal point within the network
trace_config: MirroredTraceConfig | None = None
    # used for tracing latent activations back to the activations of the DSTs that they encode
activation_index_for_within_layer_grad: MirroredActivationIndex | None = None
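# Illustrative construction sketch (not part of the API; the prompt and token strings are made up):
#     spec = InferenceRequestSpec(
#         prompt="The Eiffel Tower is in the city of",
#         loss_fn_config=LossFnConfig(
#             name=LossFnName.LOGIT_DIFF,
#             target_tokens=[" Paris"],
#             distractor_tokens=[" Rome"],
#         ),
#     )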
@immutable
class InferenceRequest(CamelCaseBaseModel):
inference_request_spec: InferenceRequestSpec
class InferenceData(CamelCaseBaseModel):
inference_time: float
memory_used_before: float | None
memory_used_after: float | None
log: str | None = None
loss: float | None = None
activation_value_for_backward_pass: float | None = None
@immutable
class InferenceAndTokenData(InferenceData):
tokens_as_ints: list[int]
tokens_as_strings: list[str]
@immutable
class InferenceResponse(CamelCaseBaseModel):
inference_and_token_data: InferenceAndTokenData
class GroupId(str, Enum):
"""Identifiers for groups in multi-top-k requests."""
ACT_TIMES_GRAD = "act_times_grad"
ACTIVATION = "activation"
DIRECT_WRITE_TO_GRAD = "direct_write_to_grad"
DIRECTION_WRITE = "direction_write"
LOGITS = "logits"
MLP_LAYER_WRITE = "mlp_layer_write"
# Used in situations where there's only one group.
SINGLETON = "singleton"
# Used for projecting write vectors of nodes to token space.
TOKEN_WRITE = "token_write"
# Used for projecting read vectors of nodes to token space.
TOKEN_READ = "token_read"
WRITE_NORM = "write_norm"
# Used for token pair attribution requests.
TOKEN_PAIR_ATTRIBUTION = "token_pair_attribution"
    @property
    def exclude_bottom_k(self) -> bool:
        # If False, top-k should return both the k largest and the k smallest (most negative)
        # activations; if True, it should return only the k largest. Generally,
        # exclude_bottom_k = True is appropriate for scalars that are non-negative (the values
        # closest to 0 are not particularly interesting). exclude_bottom_k = False is appropriate
        # for scalars that can be positive or negative (the most negative values may be
        # interesting).
        return self in {
            GroupId.WRITE_NORM,
            GroupId.ACTIVATION,
            # Logits can be positive or negative, but generally we are interested in the most
            # likely tokens to be sampled, which have the most positive logits.
            GroupId.LOGITS,
        }
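# Illustrative sketch of how exclude_bottom_k is meant to be consumed (the real top-k logic lives
# in the activation server's processing code; the values below are made up):
#     values = torch.tensor([3.0, -5.0, 0.1, 2.0])
#     k = 2
#     top_indices = torch.topk(values, k).indices  # always wanted
#     if not GroupId.ACT_TIMES_GRAD.exclude_bottom_k:
#         # For signed scalars, the most negative values are also of interest.
#         bottom_indices = torch.topk(values, k, largest=False).indices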
########## Tensors ##########
class TensorType(Enum):
TENSOR_0D = "tensor_0d"
TENSOR_1D = "tensor_1d"
TENSOR_2D = "tensor_2d"
TENSOR_3D = "tensor_3d"
class TorchableTensor(CamelCaseBaseModel):
tensor_type: TensorType
value: Any
def torch(self) -> torch.Tensor:
return torch.tensor(self.value)
@immutable
class Tensor0D(TorchableTensor):
tensor_type: TensorType = TensorType.TENSOR_0D
value: float
@immutable
class Tensor1D(TorchableTensor):
tensor_type: TensorType = TensorType.TENSOR_1D
value: list[float]
@immutable
class Tensor2D(TorchableTensor):
tensor_type: TensorType = TensorType.TENSOR_2D
value: list[list[float]]
@immutable
class Tensor3D(TorchableTensor):
tensor_type: TensorType = TensorType.TENSOR_3D
value: list[list[list[float]]]
TensorND = Union[Tensor0D, Tensor1D, Tensor2D, Tensor3D]
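# Example round trip (values are made up): the tensor_type discriminator lets pydantic pick the
# correct TensorND member when parsing serialized data, and .torch() recovers a torch.Tensor.
#     t2d = Tensor2D(value=[[1.0, 2.0], [3.0, 4.0]])
#     t2d.torch()  # tensor([[1., 2.], [3., 4.]])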
########## Model info ##########
@immutable
class ModelInfoResponse(CamelCaseBaseModel):
model_name: str | None
has_mlp_autoencoder: bool
mlp_autoencoder_name: str | None
has_attention_autoencoder: bool
attention_autoencoder_name: str | None
n_layers: int
########## Derived scalars ##########
@immutable
class DerivedScalarsRequestSpec(CamelCaseBaseModel):
# note: the spec_type field is not to be populated by the user at __init__, but is
# required for pydantic to distinguish between different XRequestSpec classes
spec_type: Literal[
SpecType.DERIVED_SCALARS_REQUEST_SPEC
] = SpecType.DERIVED_SCALARS_REQUEST_SPEC
dst: DerivedScalarType
layer_index: LayerIndex
activation_index: int
normalize_activations_using_neuron_record: NodeIdAndDatasets | None = None
"""
If non-None, the response will include normalized activations. The max scalar used for
normalization will be the max scalar in the neuron record specified by the NodeIdAndDatasets.
"""
pass_type: PassType = PassType.FORWARD
num_top_tokens: int | None = None
"""
If non-None, return the top and bottom tokens for the node, according to the scoring
methodology associated with the derived scalar type.
"""
@immutable
class DerivedScalarsRequest(InferenceRequest):
derived_scalars_request_spec: DerivedScalarsRequestSpec
@immutable
class DerivedScalarsResponseData(CamelCaseBaseModel):
response_data_type: ProcessingResponseDataType = (
ProcessingResponseDataType.DERIVED_SCALARS_RESPONSE_DATA
)
activations: list[float]
normalized_activations: list[float] | None
"""
The same activations, but normalized to [0, 1] using the max scalar in the specified neuron
record. Only set if normalize_activations_using_neuron_record was specified in the request.
"""
node_indices: list[MirroredNodeIndex]
top_tokens: TopTokens | None
"""
While this response covers multiple nodes, those nodes differ only in the sequence token index:
they all correspond to a single component (per go/tdb-terminology). Top tokens are the same for
all nodes associated with a single component, so we only need to return one set of top tokens
for the entire component. This will be None if num_top_tokens is None or if the activation was
zero, preventing the relevant write vector from being computed.
"""
@immutable
class DerivedScalarsResponse(InferenceResponse):
derived_scalars_response_data: DerivedScalarsResponseData
########## Derived attention scalars ##########
@immutable
class DerivedAttentionScalarsRequestSpec(CamelCaseBaseModel):
# note: the spec_type field is not to be populated by the user at __init__, but is
# required for pydantic to distinguish between different XRequestSpec classes
spec_type: Literal[
SpecType.DERIVED_ATTENTION_SCALARS_REQUEST_SPEC
] = SpecType.DERIVED_ATTENTION_SCALARS_REQUEST_SPEC
dst: DerivedScalarType
layer_index: LayerIndex
activation_index: int
normalize_activations_using_neuron_record: NodeIdAndDatasets | None = None
"""
If non-None, the response will include normalized activations. The max scalars used for
normalization will be the max scalars in the neuron record specified by the NodeIdAndDatasets.
"""
@immutable
class DerivedAttentionScalarsRequest(InferenceRequest):
derived_attention_scalars_request_spec: DerivedAttentionScalarsRequestSpec
@immutable
class DerivedAttentionScalarsResponseData(CamelCaseBaseModel):
response_data_type: ProcessingResponseDataType = (
ProcessingResponseDataType.DERIVED_ATTENTION_SCALARS_RESPONSE_DATA
)
token_and_attention_scalars_list: list[TokenAndAttentionScalars]
@immutable
class DerivedAttentionScalarsResponse(InferenceResponse):
derived_attention_scalars_response_data: DerivedAttentionScalarsResponseData
########## (Multi) top-k ##########
# This dataclass is not used in any requests or responses. It's used internally to represent a top-k
# operation performed as part of servicing a MultipleTopKDerivedScalarsRequest.
@immutable
class TopKParams(CamelCaseBaseModel):
dst_list: list[DerivedScalarType]
token_index: int | None
top_and_bottom_k: int | None = None
pass_type: PassType = PassType.FORWARD
exclude_bottom_k: bool = False
dimensions_to_keep_for_intermediate_sum: list[Dimension] = [
Dimension.SEQUENCE_TOKENS,
Dimension.ATTENDED_TO_SEQUENCE_TOKENS,
]
@immutable
class MultipleTopKDerivedScalarsRequestSpec(CamelCaseBaseModel):
# note: the spec_type field is not to be populated by the user at __init__, but is
# required for pydantic to distinguish between different XRequestSpec classes
spec_type: Literal[
SpecType.MULTIPLE_TOP_K_DERIVED_SCALARS_REQUEST_SPEC
] = SpecType.MULTIPLE_TOP_K_DERIVED_SCALARS_REQUEST_SPEC
dst_list_by_group_id: dict[GroupId, list[DerivedScalarType]]
    # DSTs within each group are assumed to have a defined node_type, all node_types are assumed
    # to be distinct within a group_id, and all group_ids are assumed to contain the same set of
    # node_types.
token_index: int | None
top_and_bottom_k: int | None = None
pass_type: PassType = PassType.FORWARD
dimensions_to_keep_for_intermediate_sum: list[Dimension] = [
Dimension.SEQUENCE_TOKENS,
Dimension.ATTENDED_TO_SEQUENCE_TOKENS,
]
def get_top_k_params_for_group_id(self, group_id: GroupId) -> TopKParams:
"""
A MultipleTopKDerivedScalarsRequestSpec object contains the information necessary to
generate multiple TopKParams objects, one for each group ID. This function returns the
TopKParams for a specific group ID.
"""
dst_list = self.dst_list_by_group_id[group_id]
exclude_bottom_k = group_id.exclude_bottom_k
# Convert the instance to a dictionary
data = self.dict()
        # Remove the fields that are not needed in TopKParams
data.pop("dst_list_by_group_id")
data.pop("spec_type")
        # Add the fields specific to TopKParams
data["dst_list"] = dst_list
data["exclude_bottom_k"] = exclude_bottom_k
return TopKParams(**data)
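# Usage sketch for get_top_k_params_for_group_id (group and DST values are illustrative):
#     spec = MultipleTopKDerivedScalarsRequestSpec(
#         dst_list_by_group_id={GroupId.ACTIVATION: [DerivedScalarType.MLP_POST_ACT]},
#         token_index=None,
#         top_and_bottom_k=10,
#     )
#     params = spec.get_top_k_params_for_group_id(GroupId.ACTIVATION)
#     assert params.exclude_bottom_k == GroupId.ACTIVATION.exclude_bottom_k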
# All sub-requests within a batched request must have comparable prompts, since top-k operations
# within the batch take a union over node indices (within each spec name).
@immutable
class MultipleTopKDerivedScalarsRequest(InferenceRequest):
multiple_top_k_derived_scalars_request_spec: MultipleTopKDerivedScalarsRequestSpec
@immutable
class MultipleTopKDerivedScalarsResponseData(CamelCaseBaseModel):
response_data_type: ProcessingResponseDataType = (
ProcessingResponseDataType.MULTIPLE_TOP_K_DERIVED_SCALARS_RESPONSE_DATA
)
# Activations associated with top-k nodes for this sub-request, as well as top-k nodes with the
# same spec name in other (multi) top-k requests in this batched request.
activations_by_group_id: dict[GroupId, list[float]]
# Indices for top-k nodes associated with this request, as well as top-k nodes with the same
# spec name in other (multi) top-k requests in this batched request.
node_indices: list[MirroredNodeIndex]
vocab_token_strings_for_indices: list[str | None] | None
    # The intermediate_sum_... entries give the total of all activations in each group, including
    # non-top-k activations.
intermediate_sum_activations_by_dst_by_group_id: dict[
GroupId, dict[DerivedScalarType, TensorND]
]
@root_validator
def check_consistency(cls, values: dict[str, Any]) -> dict[str, Any]:
activations_by_group_id = values.get("activations_by_group_id")
assert activations_by_group_id is not None
node_indices = values.get("node_indices")
assert node_indices is not None
vocab_token_strings_for_indices = values.get("vocab_token_strings_for_indices")
for group_id, activations in activations_by_group_id.items():
assert len(node_indices) == len(activations), (
f"Expected len(node_indices) == len(activations) for group_id {group_id},"
f" but got len(node_indices)={len(node_indices)}, len(activations)={len(activations)}"
)
assert all(math.isfinite(activation) for activation in activations), (
f"Expected all activations to be finite for group_id {group_id},"
f" but got activations={activations}"
)
if vocab_token_strings_for_indices is not None:
assert len(node_indices) == len(vocab_token_strings_for_indices), (
f"Expected len(node_indices) == len(vocab_token_strings_for_indices),"
f" but got len(node_indices)={len(node_indices)}, len(vocab_token_strings_for_indices)={len(vocab_token_strings_for_indices)}"
)
return values
@immutable
class MultipleTopKDerivedScalarsResponse(InferenceResponse):
multiple_top_k_derived_scalars_response_data: MultipleTopKDerivedScalarsResponseData
########## Scored tokens ##########
class TokenScoringType(Enum):
"""Methods by which vocab tokens may be scored."""
# Score tokens by the degree to which this node directly upvotes them. This is basically the
# "logit lens".
UPVOTED_OUTPUT_TOKENS = "upvoted_output_tokens"
    # Score tokens by the degree to which they directly upvote this node. Three flavors, each of
    # which applies both to "raw" components like neurons and attention heads and to autoencoder
    # latents:
# 1) Upvoting MLP nodes
# 2) Upvoting the Q part of attention nodes
# 3) Upvoting the K part of attention nodes
INPUT_TOKENS_THAT_UPVOTE_MLP = "input_tokens_that_upvote_mlp"
INPUT_TOKENS_THAT_UPVOTE_ATTN_Q = "input_tokens_that_upvote_attn_q"
INPUT_TOKENS_THAT_UPVOTE_ATTN_K = "input_tokens_that_upvote_attn_k"
@immutable
class ScoredTokensRequestSpec(CamelCaseBaseModel):
# note: the spec_type field is not to be populated by the user at __init__, but is
# required for pydantic to distinguish between different XRequestSpec classes
spec_type: Literal[SpecType.SCORED_TOKENS_REQUEST_SPEC] = SpecType.SCORED_TOKENS_REQUEST_SPEC
# How tokens should be scored.
token_scoring_type: TokenScoringType
# A value of e.g. 10 means 10 top and 10 bottom tokens.
num_tokens: int
# Which nodes do we want to get scored tokens for, and which DSTs and DST configs should we use?
# This request spec refers to another request spec and grabs those values from it.
depends_on_spec_name: str
@immutable
class ScoredTokensRequest(InferenceRequest):
scored_tokens_request_spec: ScoredTokensRequestSpec
@immutable
class ScoredTokensResponseData(CamelCaseBaseModel):
response_data_type: ProcessingResponseDataType = (
ProcessingResponseDataType.SCORED_TOKENS_RESPONSE_DATA
)
# These two lists are parallel and have the same length. "None" values in top_tokens_list
# indicate that the specified TokenScoringType does not apply to the corresponding node.
node_indices: list[MirroredNodeIndex]
top_tokens_list: list[TopTokens | None]
@immutable
class ScoredTokensResponse(InferenceResponse):
scored_tokens_response_data: ScoredTokensResponseData
########## TDB-specific ##########
class ComponentTypeForMlp(Enum):
"""The type of component / fundamental unit to use for MLP layers.
This determines which types of node appear in the node table to represent the MLP layers.
Neurons are the fundamental unit of MLP layers, but autoencoder latents are more interpretable.
"""
NEURON = "neuron"
AUTOENCODER_LATENT = "autoencoder_latent"
class ComponentTypeForAttention(Enum):
"""The type of component / fundamental unit to use for Attention layers.
This determines which types of node appear in the node table to represent the Attention layers.
Heads are the fundamental unit of Attention layers, but autoencoder latents are more interpretable.
"""
ATTENTION_HEAD = "attention_head"
AUTOENCODER_LATENT = "autoencoder_latent"
@immutable
class TdbRequestSpec(CamelCaseBaseModel):
# note: the spec_type field is not to be populated by the user at __init__, but is
# required for pydantic to distinguish between different XRequestSpec classes
spec_type: Literal[SpecType.TDB_REQUEST_SPEC] = SpecType.TDB_REQUEST_SPEC
prompt: str
target_tokens: list[str]
distractor_tokens: list[str]
component_type_for_mlp: ComponentTypeForMlp
"""Whether to use neurons or autoencoder latents as the basic unit for MLP layers."""
component_type_for_attention: ComponentTypeForAttention
"""Whether to use heads or autoencoder latents as the basic unit for attention layers."""
top_and_bottom_k_for_node_table: int
"""The number of top and bottom nodes to calculate for each column in the node table."""
hide_early_layers_when_ablating: bool
"""Whether to exclude layers before the first ablated layer from the results."""
node_ablations: list[NodeAblation] | None
upstream_node_to_trace: NodeToTrace | None
"""The primary node at which a gradient is being computed"""
downstream_node_to_trace: NodeToTrace | None
"""In the case where upstream_node_to_trace is an attention value subnode, you can also
provide a downstream node to trace. A gradient is first computed with respect to this downstream
node, and then the direct effect of the upstream node on this gradient direction is computed. A
gradient is then computed with respect to that quantity, propagated back to upstream activations.
In the case where no downstream node is provided, the loss is used as the "downstream node"."""
@immutable
class BatchedTdbRequest(CamelCaseBaseModel):
sub_requests: list[TdbRequestSpec]
########## Attribution ##########
@immutable
class TopTokensAttendedTo(CamelCaseBaseModel):
token_indices: list[int] # in sequence
attributions: list[float]
@immutable
class TokenPairAttributionRequestSpec(CamelCaseBaseModel):
# note: the spec_type field is not to be populated by the user at __init__, but is
# required for pydantic to distinguish between different XRequestSpec classes
spec_type: Literal[
SpecType.TOKEN_PAIR_ATTRIBUTION_REQUEST_SPEC
] = SpecType.TOKEN_PAIR_ATTRIBUTION_REQUEST_SPEC
num_tokens_attended_to: int
    # Which nodes do we want to get token-pair attributions for, and which DSTs and DST configs
    # should we use? This request spec refers to another request spec and grabs those values from it.
depends_on_spec_name: str
@immutable
class TokenPairAttributionRequest(InferenceRequest):
token_pair_attribution_request_spec: TokenPairAttributionRequestSpec
@immutable
class TokenPairAttributionResponseData(CamelCaseBaseModel):
response_data_type: ProcessingResponseDataType = (
ProcessingResponseDataType.TOKEN_PAIR_ATTRIBUTION_RESPONSE_DATA
)
# These two lists are parallel and have the same length. "None" values in top_tokens_attended_to_list
# indicate that token-pair attribution does not apply to the corresponding node.
node_indices: list[MirroredNodeIndex]
top_tokens_attended_to_list: list[TopTokensAttendedTo | None]
@immutable
class TokenPairAttributionResponse(InferenceResponse):
token_pair_attribution_response_data: TokenPairAttributionResponseData
########## Batching ##########
# Order from most to least specific
# See https://docs.pydantic.dev/1.10/usage/types/#unions
ProcessingRequestSpec = Union[
MultipleTopKDerivedScalarsRequestSpec,
DerivedScalarsRequestSpec,
DerivedAttentionScalarsRequestSpec,
ScoredTokensRequestSpec,
TokenPairAttributionRequestSpec,
]
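# Why the ordering matters: pydantic v1 tries union members left to right and accepts the first
# one that validates, so the models with more required fields must be listed first or a payload
# intended for them could be coerced into a laxer member. The Literal spec_type /
# response_data_type defaults further disambiguate parsing.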
# Order from most to least specific
# See https://docs.pydantic.dev/1.10/usage/types/#unions
ProcessingResponseData = Union[
MultipleTopKDerivedScalarsResponseData,
DerivedScalarsResponseData,
DerivedAttentionScalarsResponseData,
ScoredTokensResponseData,
TokenPairAttributionResponseData,
]
@immutable
class InferenceSubRequest(CamelCaseBaseModel):
inference_request_spec: InferenceRequestSpec
processing_request_spec_by_name: dict[str, ProcessingRequestSpec] = {}
@immutable
class InferenceResponseAndResponseDict(CamelCaseBaseModel):
inference_response: InferenceResponse
processing_response_data_by_name: dict[str, ProcessingResponseData] = {}
@immutable
class BatchedRequest(CamelCaseBaseModel):
inference_sub_requests: list[InferenceSubRequest]
@immutable
class BatchedResponse(CamelCaseBaseModel):
inference_sub_responses: list[InferenceResponseAndResponseDict]
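# End-to-end construction sketch of the structure described in the module docstring. All concrete
# values (prompt, spec names, DSTs, k) are made up; "scoredTokens" shows depends_on_spec_name
# referring to another spec by its key in processing_request_spec_by_name.
#     batched = BatchedRequest(
#         inference_sub_requests=[
#             InferenceSubRequest(
#                 inference_request_spec=InferenceRequestSpec(prompt="The cat sat on the"),
#                 processing_request_spec_by_name={
#                     "topK": MultipleTopKDerivedScalarsRequestSpec(
#                         dst_list_by_group_id={
#                             GroupId.ACTIVATION: [DerivedScalarType.MLP_POST_ACT]
#                         },
#                         token_index=None,
#                         top_and_bottom_k=10,
#                     ),
#                     "scoredTokens": ScoredTokensRequestSpec(
#                         token_scoring_type=TokenScoringType.UPVOTED_OUTPUT_TOKENS,
#                         num_tokens=10,
#                         depends_on_spec_name="topK",
#                     ),
#                 },
#             )
#         ]
#     )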