neuron_explainer/models/model_context.py (306 lines of code) (raw):

import os
from abc import ABC, abstractmethod
from typing import Any

import tiktoken
import torch
import torch.nn as nn

from neuron_explainer.file_utils import CustomFileHandler
from neuron_explainer.models import Transformer, TransformerConfig
from neuron_explainer.models.inference_engine_type_registry import InferenceEngineType
from neuron_explainer.models.model_component_registry import (
    Dimension,
    LayerIndex,
    WeightLocationType,
    get_dimension_index_of_weight_location_type,
    weight_shape_by_location_type,
)
from neuron_explainer.models.model_registry import get_standard_model_spec

# Special tokens that `ModelContext.encode` permits in its input.
ALLOWED_SPECIAL_TOKENS = {"<|endoftext|>"}


class InvalidTokenException(Exception):
    """Raised when a string that should map to exactly one token does not."""


class ModelContext(ABC):
    """Abstract interface to a model's weights, tokenizer, and dimension sizes.

    Subclasses implement weight retrieval (`_get_weight_helper`), the tiktoken encoding
    (`get_encoding`), dimension sizes (`get_dim_size`), model construction
    (`get_or_create_model`) and config export (`get_model_config_as_dict`). This base class
    builds shape checking, normalization, and token encode/decode helpers on top of those.
    """

    def __init__(self, model_name: str, device: torch.device) -> None:
        self.model_name = model_name
        # Default device used by `get_weight` when the caller does not pass one.
        self.device = device

    # takes a WeightLocationType and optional layer
    # returns the tensor
    @abstractmethod
    def _get_weight_helper(
        self,
        location_type: WeightLocationType,
        layer: LayerIndex = None,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        ...

    def get_weight(
        self,
        location_type: WeightLocationType,
        layer: LayerIndex = None,
        normalize_dim: Dimension | None = None,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """Returns the specified weights, with shape checking and optional normalization.

        Tensors returned by this method are not cloned, so please be sure not to perform in-place
        edits on them!

        Args:
            location_type: Which weight to fetch; must be a key of
                `weight_shape_by_location_type`.
            layer: Layer index forwarded to `_get_weight_helper`.
            normalize_dim: If given, the weight is passed through `nn.functional.normalize`
                along the tensor axis that corresponds to this dimension for `location_type`.
            device: Device forwarded to `_get_weight_helper`; falls back to `self.device`.

        Returns:
            The detached weight tensor.

        Raises:
            AssertionError: If `location_type` is unknown, or the retrieved weight's shape does
                not match the shape derived from `weight_shape_by_location_type`.
        """
        assert (
            location_type in weight_shape_by_location_type
        ), f"location_type_str {location_type} not found"
        weight = self._get_weight_helper(
            location_type=location_type, layer=layer, device=device or self.device
        )
        if normalize_dim is not None:
            weight = nn.functional.normalize(
                weight,
                dim=get_dimension_index_of_weight_location_type(location_type, normalize_dim),
            )
        weight_shape_spec = weight_shape_by_location_type[location_type]
        expected_shape = self.get_shape_from_spec(weight_shape_spec)
        assert (
            weight.shape == expected_shape
        ), f"Expected shape {expected_shape} but got {weight.shape}"
        # We don't want to return tensors that have gradients enabled, so we detach. Ideally we'd
        # also clone since we don't want callers to inadvertently edit the weights, but doing so
        # uses a lot of memory, so instead we just ask politely in the docstring.
        return weight.detach()

    # get Encoding -> call this in the base class
    @abstractmethod
    def get_encoding(self) -> tiktoken.Encoding:
        ...

    def encode(self, string: str) -> list[int]:
        """Encodes `string` to token ids, allowing only `ALLOWED_SPECIAL_TOKENS` as specials."""
        return self.get_encoding().encode(string, allowed_special=ALLOWED_SPECIAL_TOKENS)

    def decode_token(self, token: int) -> str:
        """Decodes a single token id to its string."""
        return self.get_encoding().decode([token])

    def decode(self, token_list: list[int]) -> str:
        """Decodes a sequence of token ids to one string."""
        return self.get_encoding().decode(token_list)

    def encode_token_str(self, token_str: str) -> int:
        """Encodes a string that is expected to be exactly one token.

        Raises:
            InvalidTokenException: If `token_str` encodes to zero tokens or more than one.
        """
        token_int_list = self.encode(token_str)
        if len(token_int_list) != 1:
            # The operation performed here is encoding, so the message says "encoded".
            raise InvalidTokenException(
                f"'{token_str}' encoded to {token_int_list}; wanted exactly 1 token"
            )
        return token_int_list[0]

    @abstractmethod
    def get_dim_size(self, model_dimension_spec: Dimension) -> int:
        ...

    def get_shape_from_spec(self, shape_spec: tuple[Dimension, ...]) -> tuple[int, ...]:
        """Converts a tuple of dimensions to concrete sizes.

        `Dimension.SINGLETON` maps to 1 without consulting `get_dim_size`; every other dimension
        is resolved through `get_dim_size`.
        """
        expected_shape: tuple[int, ...] = tuple(
            self.get_dim_size(dimension_spec) if dimension_spec != Dimension.SINGLETON else 1
            for dimension_spec in shape_spec
        )
        return expected_shape

    @abstractmethod
    def get_or_create_model(self) -> Transformer:
        """Returns an instantiated model which can be used to run forward passes.

        The first call to this method results in a new model being created. Subsequent calls
        return the same cached model instance.
        """
        ...

    def decode_token_list(self, token_list: list[int]) -> list[str]:
        """Decodes each token id separately, returning one string per token."""
        return [self.decode_token(token=token) for token in token_list]

    def encode_token_str_list(self, token_str_list: list[str]) -> list[int]:
        """Encodes each string with `encode_token_str`; each must be exactly one token."""
        return [self.encode_token_str(token_str=token_str) for token_str in token_str_list]

    @classmethod
    def from_model_type(
        cls,
        model_type: str,
        inference_engine_type: InferenceEngineType = InferenceEngineType.STANDARD,
        **kwargs: Any,
    ) -> "ModelContext":
        """Builds a concrete model context for `model_type`.

        Args:
            model_type: Passed as `model_name` to the created context.
            inference_engine_type: Only `InferenceEngineType.STANDARD` is supported.
            **kwargs: May contain `device`; everything else is forwarded to the context's
                constructor.

        Raises:
            ValueError: If `inference_engine_type` is not `InferenceEngineType.STANDARD`.
        """
        # Resolve the default lazily so the device probe only runs when no device was supplied.
        device = kwargs.pop("device", None)
        if device is None:
            device = get_default_device()
        if inference_engine_type == InferenceEngineType.STANDARD:
            return StandardModelContext(model_name=model_type, device=device, **kwargs)
        else:
            raise ValueError(f"Unsupported inference_engine_type {inference_engine_type}")

    @property
    def n_neurons(self) -> int:
        """Size of `Dimension.MLP_ACTS`."""
        return self.get_dim_size(Dimension.MLP_ACTS)

    @property
    def n_attention_heads(self) -> int:
        """Size of `Dimension.ATTN_HEADS`."""
        return self.get_dim_size(Dimension.ATTN_HEADS)

    @property
    def n_layers(self) -> int:
        """Size of `Dimension.LAYERS`."""
        return self.get_dim_size(Dimension.LAYERS)

    @property
    def n_residual_stream_channels(self) -> int:
        """Size of `Dimension.RESIDUAL_STREAM_CHANNELS`."""
        return self.get_dim_size(Dimension.RESIDUAL_STREAM_CHANNELS)

    @property
    def n_vocab(self) -> int:
        """Size of `Dimension.VOCAB_SIZE`."""
        return self.get_dim_size(Dimension.VOCAB_SIZE)

    @property
    def n_context(self) -> int:
        """Size of `Dimension.MAX_CONTEXT_LENGTH`."""
        return self.get_dim_size(Dimension.MAX_CONTEXT_LENGTH)

    @abstractmethod
    def get_model_config_as_dict(self) -> dict[str, Any]:
        ...


# Note: If you're seeing mysterious crashes while running on a MacBook, try switching from "mps" to
# "cpu".
def get_default_device() -> torch.device:
    """Picks the device to use when the caller does not specify one.

    Order of preference: CUDA device 0 if available, then MPS device 0 if available and not
    running under pytest, otherwise CPU.
    """
    # TODO: Figure out why test_interactive_model.py crashes on the "mps" backend, then remove
    # this workaround.
    is_pytest = "PYTEST_CURRENT_TEST" in os.environ
    if torch.cuda.is_available():
        return torch.device("cuda", 0)
    elif torch.backends.mps.is_available() and not is_pytest:
        return torch.device("mps", 0)
    else:
        return torch.device("cpu")


class StandardModelContext(ModelContext):
    """Model context backed by the files under a standard model spec's `model_path`.

    The config is read from `{load_path}/config.json`. Weights are read either from
    `{load_path}/model_pieces/{part}.pt` or, once a model has been created, from the cached
    transformer's state dict.
    """

    def __init__(self, model_name: str, device: torch.device | None = None) -> None:
        if device is None:
            device = get_default_device()
        super().__init__(model_name=model_name, device=device)
        self._model_spec = get_standard_model_spec(self.model_name)
        self.load_path = self._model_spec.model_path
        self._config = TransformerConfig.load(f"{self.load_path}/config.json")

        # Once a transformer has been created via get_or_create_model, we cache it. Subsequent calls
        # to get_or_create_model return the cached instance.
        self._cached_transformer: Transformer | None = None

    @classmethod
    def from_model_type(
        cls,
        model_type: str,
        inference_engine_type: InferenceEngineType = InferenceEngineType.STANDARD,
        **kwargs: Any,
    ) -> ModelContext:  # specifically a StandardModelContext, but to satisfy mypy
        """Builds a `StandardModelContext`; `inference_engine_type` must be left as STANDARD.

        Raises:
            AssertionError: If a non-STANDARD `inference_engine_type` is passed, or the parent
                factory returns something other than a `StandardModelContext`.
        """
        assert (
            inference_engine_type == InferenceEngineType.STANDARD
        ), "don't set a different inference_engine_type kwarg here"
        model_context = super().from_model_type(
            model_type=model_type, inference_engine_type=InferenceEngineType.STANDARD, **kwargs
        )
        assert isinstance(model_context, StandardModelContext)
        return model_context

    def get_dim_size(self, model_dimension_spec: Dimension) -> int:
        """Returns the size of a dimension, read from the config (or the encoding for vocab).

        Raises:
            KeyError: If `model_dimension_spec` is not one of the dimensions mapped below.
        """
        # TODO(sbills): This should really be a match statement.
        # Each dimension appears once; a repeated key in a dict literal is silently overwritten.
        dimension_by_dimension_spec: dict[Dimension, int] = {
            Dimension.MAX_CONTEXT_LENGTH: self._config.ctx_window,
            Dimension.RESIDUAL_STREAM_CHANNELS: self._config.d_model,
            Dimension.VOCAB_SIZE: self.get_encoding().n_vocab,
            Dimension.ATTN_HEADS: self._config.n_heads,
            Dimension.QUERY_AND_KEY_CHANNELS: self._config.d_head_qk,
            Dimension.VALUE_CHANNELS: self._config.d_head_v,
            Dimension.MLP_ACTS: self._config.d_ff,
            Dimension.LAYERS: self._config.n_layers,
        }
        return dimension_by_dimension_spec[model_dimension_spec]

    def _get_weight_helper(
        self,
        location_type: WeightLocationType,
        layer: LayerIndex = None,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        """Loads one weight tensor and rearranges it according to its location type.

        Each supported location type maps to a dict with:
          * "part": name of the stored tensor (file stem / state-dict key),
          * "split" (optional): `(dim, n)`; axis `dim` is reshaped into `(n, size // n)`,
          * "reshape" (optional): a single-operand einsum pattern applied after the split,
          * "broadcast" (optional): if True, expand the result to the expected shape.

        Raises:
            NotImplementedError: If `location_type` has no entry in the mapping.
        """
        info_by_type: dict[WeightLocationType, dict] = {
            WeightLocationType.MLP_TO_HIDDEN: dict(
                part=f"xf_layers.{layer}.mlp.in_layer.weight",
                reshape="hr->rh",
            ),
            WeightLocationType.MLP_TO_RESIDUAL: dict(
                part=f"xf_layers.{layer}.mlp.out_layer.weight",
                reshape="rh->hr",
            ),
            WeightLocationType.EMBEDDING: dict(
                part="tok_embed.weight",
            ),
            WeightLocationType.UNEMBEDDING: dict(
                part="unembed.weight",
                reshape="vr->rv",
            ),
            WeightLocationType.POSITION_EMBEDDING: dict(
                part="pos_embed.weight",
            ),
            WeightLocationType.ATTN_TO_QUERY: dict(
                part=f"xf_layers.{layer}.attn.q_proj.weight",
                split=(0, self._config.n_heads),
                reshape="hqr->hrq",
            ),
            WeightLocationType.ATTN_TO_KEY: dict(
                part=f"xf_layers.{layer}.attn.k_proj.weight",
                split=(0, self._config.n_heads),
                reshape="hkr->hrk",
            ),
            WeightLocationType.ATTN_TO_VALUE: dict(
                part=f"xf_layers.{layer}.attn.v_proj.weight",
                split=(0, self._config.n_heads),
                reshape="hvr->hrv",
            ),
            WeightLocationType.ATTN_TO_RESIDUAL: dict(
                part=f"xf_layers.{layer}.attn.out_proj.weight",
                split=(1, self._config.n_heads),
                reshape="rhv->hvr",
            ),
            WeightLocationType.LAYER_NORM_GAIN_FINAL: dict(
                part="final_ln.weight",
                broadcast=True,
            ),
            WeightLocationType.LAYER_NORM_BIAS_FINAL: dict(
                part="final_ln.bias",
            ),
            WeightLocationType.LAYER_NORM_GAIN_PRE_ATTN: dict(
                part=f"xf_layers.{layer}.ln_1.weight",
            ),
            WeightLocationType.LAYER_NORM_GAIN_PRE_MLP: dict(
                part=f"xf_layers.{layer}.ln_2.weight",
            ),
        }

        info = info_by_type.get(location_type)
        if info is None:
            raise NotImplementedError(f"Unsupported weight location type: {location_type}")
        part = info["part"]
        split = info.get("split")
        reshape = info.get("reshape")

        # Read from disk until a model has been created; afterwards reuse the cached model's
        # parameters.
        if self._cached_transformer is None:
            with CustomFileHandler(f"{self.load_path}/model_pieces/{part}.pt", "rb") as f:
                weight = torch.load(f, map_location=device or self.device)
        else:
            weight = self._cached_transformer.state_dict()[part].to(device or self.device)

        if split is not None:
            # Replace axis `dim_split` with two axes: (n_split, original_size // n_split).
            (dim_split, n_split) = split
            w_shape = list(weight.shape)
            w_shape_new = (
                w_shape[:dim_split]
                + [n_split, w_shape[dim_split] // n_split]
                + w_shape[dim_split + 1 :]
            )
            weight = weight.reshape(*w_shape_new)
        if reshape is not None:
            weight = torch.einsum(reshape, weight)

        # Some tensors are sometimes stored with a subset of dimensions and then broadcasted in the
        # model. E.g. the final layer norm gain is stored as a scalar.
        # Broadcast flag indicates that we should broadcast them to the expected shape
        broadcast = info.get("broadcast")
        if broadcast is True:
            expected_shape = self.get_shape_from_spec(weight_shape_by_location_type[location_type])
            weight = weight.expand(expected_shape)

        return weight

    def get_or_create_model(
        self,
        device: torch.device | None = None,
        simplify: bool = False,
    ) -> Transformer:
        """Returns the cached transformer, loading it from `load_path` on the first call.

        `device` and `simplify` are only used when the model is first loaded; once a model is
        cached, later calls return that instance regardless of the arguments passed.
        """
        if self._cached_transformer is None:
            self._cached_transformer = Transformer.load(
                self.load_path, simplify=simplify, device=device or self.device
            )
        return self._cached_transformer

    def get_encoding(self) -> tiktoken.Encoding:
        """Returns the tiktoken encoding named by the config's `enc` field."""
        return tiktoken.get_encoding(self._config.enc)

    def get_model_config_as_dict(self) -> dict[str, Any]:
        """Returns the transformer config as a dict."""
        return self._config.to_dict()


class StubModelContext(ModelContext):
    """This is a fake model context object for use in testing. It just works as a holder for a
    mapping from model dimension to size (int)."""

    # TODO: maybe make a unified interface for the Config objects of ModelContext objects, and
    # have this be a StubConfig instead of a StubContext

    def __init__(
        self,
        size_by_model_dimension_spec: dict[Dimension, int],
    ) -> None:
        super().__init__(model_name="stub", device=torch.device("cpu"))
        self._size_by_model_dimension_spec = size_by_model_dimension_spec

    def _get_weight_helper(
        self,
        location_type: WeightLocationType,
        layer: LayerIndex = None,
        device: torch.device | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_encoding(self) -> tiktoken.Encoding:
        raise NotImplementedError

    def get_or_create_model(self) -> Transformer:
        raise NotImplementedError

    def get_model_config_as_dict(self) -> dict[str, Any]:
        raise NotImplementedError

    def get_dim_size(self, model_dimension_spec: Dimension) -> int:
        """Returns the configured size, or raises NotImplementedError for unknown dimensions."""
        if model_dimension_spec not in self._size_by_model_dimension_spec:
            raise NotImplementedError
        return self._size_by_model_dimension_spec[model_dimension_spec]


# TODO: make this robust to whether the transformer is 'simplified' in our terminology
# once the .simplify() operation is extended to cover final layer norm gain
def get_unembedding_with_ln_gain(model_context: ModelContext) -> torch.Tensor:
    """
    returns an unembedding matrix multiplied by the layer norm gain (a d_model-dimensional vector)
    for the final layer
    """
    unembedding = model_context.get_weight(
        location_type=WeightLocationType.UNEMBEDDING,
        device=model_context.device,
    )
    ln_gain_final = model_context.get_weight(
        location_type=WeightLocationType.LAYER_NORM_GAIN_FINAL,
        device=model_context.device,
    )
    # Scale each row of the unembedding (first axis) by the matching gain entry.
    return torch.einsum("ov,o->ov", unembedding, ln_gain_final)


def get_embedding(model_context: ModelContext) -> torch.Tensor:
    """
    returns an embedding matrix. Note that there is no layer norm in between the embedding tensor
    and the residual stream
    """
    return model_context.get_weight(
        location_type=WeightLocationType.EMBEDDING,
        device=model_context.device,
    )