neuron_explainer/models/transformer.py (522 lines of code) (raw):

import json
import os.path as osp
import pickle
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import cache
from typing import Any, Self, Union

import numpy as np
import tiktoken
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions.categorical import Categorical
from torch.utils.checkpoint import checkpoint

from neuron_explainer.file_utils import CustomFileHandler, copy_to_local_cache, file_exists
from neuron_explainer.models.hooks import (
    AttentionHooks,
    MLPHooks,
    NormalizationHooks,
    TransformerHooks,
)

# for static analysis
Device = Union[torch.device, str]

# NOTE: some code from this file related to attention, MLP, and layernorm operations is
# copy-pasted in neuron_explainer/activations/derived_scalars/reconstituted.py; if those
# operations change here, they should correspondingly be changed in that file.


class SerializableDataclass:
    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, d) -> Self:
        return cls(**d)

    def save(self, path: str) -> None:
        if path.endswith((".pkl", ".pickle")):
            with CustomFileHandler(path, "wb") as f:
                pickle.dump(self.to_dict(), f)
        elif path.endswith(".json"):
            with CustomFileHandler(path, "w") as f:
                json.dump(self.to_dict(), f)
        else:
            raise ValueError(f"Unknown file extension for {path}")

    @classmethod
    def load(cls, path: str) -> Self:
        if path.endswith((".pkl", ".pickle")):
            with CustomFileHandler(path, "rb") as f:
                return cls.from_dict(pickle.load(f))
        elif path.endswith(".json"):
            with CustomFileHandler(path, "r") as f:
                return cls.from_dict(json.load(f))
        else:
            raise ValueError(f"Unknown file extension for {path}")


@dataclass
class TransformerConfig(SerializableDataclass):
    enc: str = "gpt2"
    ctx_window: int = 1024
    d_model: int = 256
    n_layers: int = 2
    # attn
    m_attn: float = 1
    n_heads: int = 8
    # mlp
    m_mlp: float = 4

    @property
    def d_ff(self) -> int:
        return int(self.d_model * self.m_mlp)

    @property
    def d_attn_qk(self) -> int:
        return int(self.d_model * self.m_attn)

    @property
    def d_attn_v(self) -> int:
        return int(self.d_model * self.m_attn)

    @property
    def d_head_qk(self) -> int:
        return safe_div(self.d_attn_qk, self.n_heads)

    @property
    def d_head_v(self) -> int:
        return safe_div(self.d_attn_v, self.n_heads)


def default_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def safe_div(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator


# ====================
# Attention utilities
# ====================


@cache
def causal_attn_mask(size: int, device: Device = "cpu") -> Tensor:
    return torch.tril(torch.ones(size, size)).bool().to(device)


def split_heads(Z: Tensor, n_heads: int) -> Tensor:
    batch, seq, d_attn = Z.shape
    return Z.reshape(batch, seq, n_heads, d_attn // n_heads)


def merge_heads(Z: Tensor) -> Tensor:
    batch, seq, n_heads, d_head = Z.shape
    return Z.reshape(batch, seq, n_heads * d_head)
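
# Illustrative sketch: a minimal shape check for the attention utilities above. The
# `_demo_attention_utils` helper and the tensor sizes it uses are arbitrary example
# choices, not part of the library's API.
def _demo_attention_utils() -> None:
    batch, seq, n_heads, d_head = 2, 5, 8, 4
    Z = torch.randn(batch, seq, n_heads * d_head)
    Z_split = split_heads(Z, n_heads)  # [batch, seq, n_heads, d_head]
    assert Z_split.shape == (batch, seq, n_heads, d_head)
    assert torch.equal(merge_heads(Z_split), Z)  # split/merge round trip is lossless
    M = causal_attn_mask(seq)  # lower-triangular boolean mask, [seq, seq]
    assert M.shape == (seq, seq)
    assert not M[0, 1].item()  # a query may not attend to a later key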

# ===================================
# MLP utilities
# ===================================


def gelu(x: Tensor) -> Tensor:
    return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
    # return x * torch.sigmoid(1.702 * x)


# ========================================
# Sampling, padding and related utilities
# ========================================


def prep_input_and_right_pad_for_forward_pass(
    X: list[list[int]], device: Device = "cpu"
) -> tuple[Tensor, Tensor]:
    # Helper function. The two tensors returned by this function are suitable to be passed to
    # Transformer.forward.
    return prep_input_and_pad(X, "right", device)


def prep_input_and_pad(
    X: list[list[int]], pad_side: str, device: Device = "cpu"
) -> tuple[Tensor, Tensor]:
    # X is a list of tokenized prompts; prompts may have unequal lengths. This function will
    # pad X on the side given by pad_side, putting "-1" in all the slots where a prompt is
    # shorter than the longest prompt. Then convert X into a tensor of int tokens. Then build
    # the pad tensor by looking for the "-1"s. Then fill the "-1"s in X with "0"s so the
    # embedding layer doesn't get upset.
    max_len = max([len(prompt) for prompt in X])

    def pad(x):
        padding = [-1] * (max_len - len(x))
        if pad_side == "left":
            return padding + x
        elif pad_side == "right":
            return x + padding
        else:
            raise ValueError(f"pad_side must be 'left' or 'right', not {pad_side}")

    X_tensor = torch.LongTensor([pad(prompt) for prompt in X]).to(device)
    pad = X_tensor == -1
    X_tensor = torch.where(X_tensor == -1, 0, X_tensor)
    return X_tensor, pad


def prep_pos_from_pad_and_prev_lens(pad: Tensor, prev_lens: Tensor) -> Tensor:
    # pad has shape b x s, prev_lens has shape b x 1.
    # For position embedding, we need a tensor of shape (b x s) whose entries are the
    # positions of X in the sequence. When sampling with prompts of unequal length, X is
    # left padded with pad tokens. The position tensor needs to take that into account.
    pos = torch.logical_not(pad).long().cumsum(dim=-1) - 1
    pos = torch.where(pos == -1, 0, pos)
    return pos + prev_lens


def nucleus_sample(logits: Tensor, top_p: float) -> Tensor:
    # top_p in [0, 1] is the total probability mass of top outputs.
    # based on https://nn.labml.ai/sampling/nucleus.html
    # input shape: [..., n_vocab] -> output shape: [...]
    sorted_logits, idxs = torch.sort(logits, dim=-1, descending=True)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    # logic to ensure there is always at least one token with nonzero
    # probability when selecting the nucleus.
    p0 = cum_probs[..., 0]
    top_p = torch.where(p0 > top_p, p0, top_p)[..., None]
    # sampling
    do_not_sample = cum_probs > top_p
    sorted_logits = sorted_logits.masked_fill(do_not_sample, float("-inf"))
    dist = Categorical(logits=sorted_logits)
    samples = dist.sample()
    tokens = idxs.gather(-1, samples.unsqueeze(-1)).squeeze(-1)
    return tokens
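
# Illustrative sketch: nucleus sampling over a toy 4-token "vocabulary". The probabilities
# and `_demo_nucleus_sample` helper are invented example values. With these probabilities
# and top_p=0.5, only the single most likely token fits inside the nucleus, so every draw
# returns token id 1.
def _demo_nucleus_sample() -> None:
    torch.manual_seed(0)
    logits = torch.log(torch.tensor([[0.05, 0.6, 0.3, 0.05]]))  # probs sum to 1
    draws = torch.stack([nucleus_sample(logits, top_p=0.5) for _ in range(20)])
    assert set(draws.flatten().tolist()) == {1}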

# ===============
# Layer Norm
# ===============


class Norm(nn.Module):
    """LayerNorm reimplementation with hooks."""

    def __init__(
        self,
        size: int,
        eps: float = 1e-5,
        device: Device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        kwargs = {"device": device, "dtype": dtype}
        self.size = size
        self.weight = nn.Parameter(torch.empty(size, **kwargs))  # type: ignore[arg-type]
        self.bias = nn.Parameter(torch.empty(size, **kwargs))  # type: ignore[arg-type]
        self.eps = eps

    def forward(self, x: Tensor, hooks: NormalizationHooks | None = None) -> Tensor:
        if hooks is None:
            hooks = NormalizationHooks()
        # always do norm in fp32
        orig_dtype = x.dtype
        x = x.float()
        x = x - x.mean(axis=-1, keepdim=True)  # [batch, pos, length]
        x = hooks.post_mean_subtraction(x)
        scale = torch.sqrt((x**2).mean(dim=-1, keepdim=True) + self.eps)
        scale = hooks.scale(scale)
        x = x / scale
        x = hooks.post_scale(x)
        ret = x * self.weight + self.bias
        return ret.to(orig_dtype)


def apply_layernorm_foldin(ln: Norm, linears: list[nn.Linear]) -> None:
    # Folds a layernorm's weight/bias into the next linear layer(s):
    # ln(x) = W_ln * (x - x.mean()) / x.std() + b_ln
    # linear(ln(x)) = W_linear * (W_ln * (x - x.mean()) / x.std() + b_ln) + b_linear
    W_ln = ln.weight.float()
    b_ln = ln.bias.float()
    for linear in linears:
        W_linear = linear.weight.float()
        b_linear = linear.bias.float()
        W_composed = W_linear * W_ln[None, :]
        b_composed = b_linear + W_linear @ b_ln
        # should only copy after new weights are calculated
        linear.weight.data.copy_(W_composed)
        linear.bias.data.copy_(b_composed)
    ln.weight.data[:] = 1
    ln.bias.data[:] = 0
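
# Illustrative sketch: checks that folding a Norm's affine parameters into a following
# linear layer leaves linear(ln(x)) numerically unchanged. The `_demo_layernorm_foldin`
# helper, sizes, and random initialization are example choices (Norm parameters are
# allocated with torch.empty, so they are initialized explicitly here).
def _demo_layernorm_foldin() -> None:
    torch.manual_seed(0)
    d_model, d_out = 8, 16
    ln = Norm(d_model)
    nn.init.normal_(ln.weight)
    nn.init.normal_(ln.bias)
    linear = nn.Linear(d_model, d_out)
    x = torch.randn(2, 3, d_model)
    before = linear(ln(x))
    apply_layernorm_foldin(ln, [linear])
    after = linear(ln(x))  # ln is now the identity affine (weight=1, bias=0)
    assert torch.allclose(before, after, atol=1e-4)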

# ===========================================
# Attention layers and associated components
# ===========================================


@dataclass
class KeyValueCache:
    """KV cache to save on compute"""

    K_cache: Tensor | None = None  # b x s_old x d
    V_cache: Tensor | None = None  # b x s_old x d
    pad_cache: Tensor | None = None  # b x s_old

    def update(self, K: Tensor, V: Tensor, pad: Tensor):
        # K, V have shape: b x (s_new - s_old) x d
        # pad has shape: b x (s_new - s_old)
        new = self.K_cache is None
        self.K_cache = K if new else torch.cat([self.K_cache, K], dim=1)
        self.V_cache = V if new else torch.cat([self.V_cache, V], dim=1)
        self.pad_cache = pad if new else torch.cat([self.pad_cache, pad], dim=1)
        return self.K_cache, self.V_cache, self.pad_cache


class MultiHeadedDotProductSelfAttention(nn.Module):
    """A configurable multi-headed dot product attention layer."""

    def __init__(
        self,
        cfg: TransformerConfig,
        layer_idx: int,
        device: Device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.n_heads = cfg.n_heads
        # make layers
        kwargs = {"device": device, "dtype": dtype}
        self.q_proj = nn.Linear(cfg.d_model, cfg.d_attn_qk, **kwargs)
        self.k_proj = nn.Linear(cfg.d_model, cfg.d_attn_qk, **kwargs)
        self.v_proj = nn.Linear(cfg.d_model, cfg.d_attn_v, **kwargs)
        self.out_proj = nn.Linear(cfg.d_attn_v, cfg.d_model, **kwargs)
        self.qk_scale = 1 / np.sqrt(np.sqrt(cfg.d_head_qk))
        self.cfg = cfg

    def forward(
        self,
        X: Tensor,
        kv_cache: KeyValueCache | None = None,
        pad: Tensor | None = None,
        hooks: AttentionHooks = AttentionHooks(),
    ) -> tuple[Tensor, KeyValueCache]:
        Q = self.q_proj(X)
        K = self.k_proj(X)
        V = self.v_proj(X)

        # update KV cache
        if kv_cache is None:
            kv_cache = KeyValueCache()
        K, V, pad = kv_cache.update(K, V, pad)

        # split apart heads, rescale QK
        Q = split_heads(Q, self.n_heads) * self.qk_scale
        Q = hooks.q(Q)  # bshd
        K = split_heads(K, self.n_heads) * self.qk_scale
        K = hooks.k(K)  # bshd
        V = split_heads(V, self.n_heads)
        V = hooks.v(V)  # bshd

        # useful for calculations below
        n_queries, n_keys = Q.shape[1], K.shape[1]

        # softmax multi-headed dot product attention
        pre_softmax = torch.einsum("bqhd,bkhd -> bhqk", Q, K)

        # apply causal attention mask
        M = causal_attn_mask(n_keys, device=X.device)
        M = M[None, None, -n_queries:]  # make M broadcastable to batch, head
        pre_softmax = pre_softmax.masked_fill(torch.logical_not(M), float("-inf"))

        # apply pad mask
        if pad is not None and torch.any(pad):
            # we only mask out pad tokens for non-pad query tokens
            # (because masking all pad tokens => empty rows => NaNs later)
            pad_mask = torch.bitwise_xor(pad[:, None, :], pad[:, :, None])
            # make pad broadcastable on head dim, and slice for current queries only
            pad_mask = pad_mask[:, None, -n_queries:]
            # apply pad mask
            pre_softmax = pre_softmax.masked_fill(pad_mask, float("-inf"))

        pre_softmax = torch.einsum("bhqk->bqkh", pre_softmax)
        pre_softmax = hooks.qk_logits(pre_softmax)
        pre_softmax = pre_softmax.float()  # for numerical stability
        if hooks.qk_softmax_denominator.is_empty():
            attn = torch.softmax(pre_softmax, dim=-2)
        else:
            # factor out softmax in order to hook
            pre_softmax_max = torch.max(pre_softmax, -2, keepdim=True)[0].detach()
            numerator = torch.exp(pre_softmax - pre_softmax_max)
            denominator = numerator.sum(dim=-2, keepdim=True)
            denominator = hooks.qk_softmax_denominator(denominator)
            attn = numerator / denominator
        attn = attn.to(Q.dtype)
        attn = hooks.qk_probs(attn)

        out = torch.einsum("bqkh,bkhd->bqhd", attn, V)
        out = hooks.v_out(out)
        out = merge_heads(out)  # concatenate results from all heads

        # final output projection
        return self.out_proj(out), kv_cache
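
# Illustrative sketch: how KeyValueCache grows across incremental decoding steps. The
# `_demo_kv_cache` helper and its shapes are arbitrary example values.
def _demo_kv_cache() -> None:
    batch, d_attn = 2, 16
    cache = KeyValueCache()
    # first forward pass: a 5-token prompt per sequence
    K = torch.randn(batch, 5, d_attn)
    V = torch.randn(batch, 5, d_attn)
    pad = torch.zeros(batch, 5, dtype=torch.bool)
    K_all, V_all, pad_all = cache.update(K, V, pad)
    assert K_all.shape == (batch, 5, d_attn)
    # next forward pass: one newly sampled token per sequence
    K_new = torch.randn(batch, 1, d_attn)
    V_new = torch.randn(batch, 1, d_attn)
    pad_new = torch.zeros(batch, 1, dtype=torch.bool)
    K_all, V_all, pad_all = cache.update(K_new, V_new, pad_new)
    assert K_all.shape == (batch, 6, d_attn) and pad_all.shape == (batch, 6)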

# =====================================
# MLP layers and associated components
# =====================================


class MLP(nn.Module):
    """An MLP for a transformer is a simple two-layer network."""

    def __init__(
        self, cfg: TransformerConfig, device: Device | None = None, dtype: torch.dtype | None = None
    ):
        super().__init__()
        kwargs = {"device": device, "dtype": dtype}
        self.in_layer = nn.Linear(cfg.d_model, cfg.d_ff, **kwargs)
        self.out_layer = nn.Linear(cfg.d_ff, cfg.d_model, **kwargs)
        self.act = gelu

    def forward(self, X: Tensor, hooks: MLPHooks = MLPHooks()) -> Tensor:
        pre = self.in_layer(X)
        pre = hooks.pre_act(pre)
        a = self.act(pre)
        a = hooks.post_act(a, out_layer=self.out_layer)
        out = self.out_layer(a)
        return out


# =============
# Transformers
# =============


class TransformerLayer(nn.Module):
    def __init__(
        self,
        cfg: TransformerConfig,
        layer_idx: int,
        device: Device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        kwargs = {"device": device, "dtype": dtype}
        self.cfg = cfg
        self.attn = MultiHeadedDotProductSelfAttention(cfg, layer_idx, **kwargs)
        self.mlp = MLP(cfg, **kwargs)
        self.ln_1 = Norm(cfg.d_model, **kwargs)
        self.ln_2 = Norm(cfg.d_model, **kwargs)
        self.layer_idx = layer_idx

    def simplify(self) -> None:
        ln_1_linears: list[Any] = [
            self.attn.q_proj,
            self.attn.k_proj,
            self.attn.v_proj,
        ]
        apply_layernorm_foldin(self.ln_1, ln_1_linears)
        ln_2_linears: list[Any] = [self.mlp.in_layer]
        apply_layernorm_foldin(self.ln_2, ln_2_linears)

    def attn_block(
        self, X: Tensor, kv_cache: KeyValueCache | None, pad: Tensor | None, hooks: TransformerHooks
    ) -> tuple[Tensor, KeyValueCache]:
        ln_X = self.ln_1(X, hooks.resid.torso.ln_attn)
        ln_X = hooks.resid.torso.post_ln_attn(ln_X)
        attn_delta, kv_cache = self.attn(ln_X, kv_cache, pad, hooks.attn)
        attn_delta = hooks.resid.torso.delta_attn(attn_delta)
        return attn_delta, kv_cache

    def mlp_block(self, X: Tensor, hooks: TransformerHooks) -> Tensor:
        ln_X = self.ln_2(X, hooks.resid.torso.ln_mlp)
        ln_X = hooks.resid.torso.post_ln_mlp(ln_X)
        mlp_delta = self.mlp(ln_X, hooks.mlp)
        mlp_delta = hooks.resid.torso.delta_mlp(mlp_delta)
        return mlp_delta

    def forward(
        self,
        X: Tensor,
        kv_cache: KeyValueCache | None = None,
        pad: Tensor | None = None,
        hooks: TransformerHooks = TransformerHooks(),
    ) -> tuple[Tensor, KeyValueCache]:
        attn_delta, kv_cache = self.attn_block(X, kv_cache, pad, hooks)
        X = X + attn_delta
        X = hooks.resid.torso.post_attn(X)
        mlp_delta = self.mlp_block(X, hooks)
        X = X + mlp_delta
        X = hooks.resid.torso.post_mlp(X)
        return X, kv_cache


class HiddenState:
    """A hidden state for a transformer. Tracks prompt lengths and KV caches."""

    def __init__(self, n_layers: int):
        self.prev_lens = 0
        self.kv_caches: list[KeyValueCache | None] = [None for _ in range(n_layers)]

    def set_prev_lens(self, prev_lens) -> None:
        self.prev_lens = prev_lens

    def __getitem__(self, idx: int):
        return self.kv_caches[idx]

    def __setitem__(self, idx: int, value: KeyValueCache | None):
        self.kv_caches[idx] = value
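
# Illustrative sketch: running a single TransformerLayer on random activations. The tiny
# config, random weights, and `_demo_transformer_layer` helper are example choices; Norm
# parameters are allocated with torch.empty, so they are set explicitly here.
def _demo_transformer_layer() -> None:
    torch.manual_seed(0)
    cfg = TransformerConfig(d_model=32, n_heads=4, n_layers=1)
    layer = TransformerLayer(cfg, layer_idx=0)
    for m in layer.modules():
        if isinstance(m, Norm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
    X = torch.randn(2, 7, cfg.d_model)  # [batch, seq, d_model]
    out, kv_cache = layer(X)
    assert out.shape == X.shape  # residual stream shape is preserved
    assert kv_cache.K_cache.shape == (2, 7, cfg.d_attn_qk)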

class Transformer(nn.Module):
    def __init__(
        self,
        cfg: TransformerConfig,
        # recomputing is optional, and it trades off compute for memory.
        recompute: bool = False,
        device: Device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.cfg = cfg
        self.enc = tiktoken.get_encoding(self.cfg.enc)
        self.n_vocab = self.enc.n_vocab
        self.recompute = recompute
        self.dtype = dtype
        # build network
        kwargs = {"device": device, "dtype": dtype}
        self.tok_embed = nn.Embedding(self.n_vocab, cfg.d_model, **kwargs)
        self.pos_embed = nn.Embedding(cfg.ctx_window, cfg.d_model, **kwargs)
        self.xf_layers = nn.ModuleList(
            [TransformerLayer(cfg, idx, **kwargs) for idx in range(cfg.n_layers)]
        )
        self.final_ln = Norm(cfg.d_model, **kwargs)
        self.unembed = nn.Linear(cfg.d_model, self.n_vocab, bias=False, **kwargs)

    def simplify(self):
        for xf_layer in self.xf_layers:
            xf_layer.simplify()
        # NOTE: we can't fold layer norm into the unembedding layer because it has no bias
        # apply_layernorm_foldin(self.final_ln, [self.unembed])

    @property
    def device(self) -> Device:
        return next(self.parameters()).device

    def set_recompute(self, recompute: bool) -> None:
        self.recompute = recompute

    def forward(
        self,
        tokens: Tensor,
        H: HiddenState | None = None,
        pad: Tensor | None = None,
        hooks: TransformerHooks = TransformerHooks(),
    ) -> tuple[Tensor, HiddenState]:
        """
        Forward pass through the transformer!

        During evaluation or the first forward pass in sampling:
        tokens is expected to be a [batch_size x sequence_length]-shaped LongTensor of
        encoded prompts.
        H is expected to be None.
        pad is a [batch_size x sequence_length]-shaped boolean Tensor. "1"s mean "ignore this
        token". This parameter must be set if not all encoded prompts have the same length.
        Note that activations observed by hooks will include padded values.

        During sampling after the first forward pass:
        tokens is expected to be the new part of the sequences (e.g. the most recently
        sampled tokens).
        H is expected to have KV-caches of all Keys and Values for prior tokens.
        pad is expected to be None (new tokens are not pad tokens).

        Returns a tuple containing the resulting logits tensor and a new hidden state
        consisting of a KV cache.
        """
        X, H, pad, hooks = self.run_embed(tokens, H, pad, hooks)
        X, H, pad, hooks = self.run_torso(X, H, pad, hooks)
        return self.run_unembed(X, H, hooks)

    def run_embed(
        self,
        tokens: Tensor,
        H: HiddenState | None = None,
        pad: Tensor | None = None,
        hooks: TransformerHooks = TransformerHooks(),
    ) -> tuple[Tensor, HiddenState, Tensor | None, TransformerHooks]:
        assert tokens.dtype == torch.long, "tokens must be a LongTensor of token ids."
        if H is None:
            H = HiddenState(self.cfg.n_layers)
        if pad is None:
            pad = torch.zeros_like(tokens, dtype=torch.bool)
        # embedding
        X = self.tok_embed(tokens)
        # position encoding logic to support sampling with prompts of unequal length.
        pos = prep_pos_from_pad_and_prev_lens(pad, H.prev_lens)
        seq_lens = (pos[:, -1] + 1).unsqueeze(-1)
        assert all(
            seq_lens <= self.cfg.ctx_window
        ), f"sequences must fit in the context window {self.cfg.ctx_window}."
        H.set_prev_lens(seq_lens)
        X = X + self.pos_embed(pos)
        X = hooks.resid.post_emb(X)
        return X, H, pad, hooks

    def run_torso(
        self,
        X: Tensor,
        H: HiddenState | None,
        pad: Tensor | None,
        hooks: TransformerHooks,
    ) -> tuple[Tensor, HiddenState, Tensor | None, TransformerHooks]:
        # transformer torso
        for i, xf_layer in enumerate(self.xf_layers):
            hooks_layer_i = deepcopy(hooks).bind(layer=i)
            if self.recompute:
                X, H[i] = checkpoint(xf_layer, X, H[i], pad, hooks_layer_i)
            else:
                X, H[i] = xf_layer(X, H[i], pad, hooks_layer_i)
        return X, H, pad, hooks

    def run_ln_f(
        self,
        X: Tensor,
        H: HiddenState | None,
        hooks: TransformerHooks,
    ) -> tuple[Tensor, HiddenState, TransformerHooks]:
        X = self.final_ln(X, hooks.resid.ln_f)
        X = hooks.resid.post_ln_f(X)
        return X, H, hooks

    def run_unembed(
        self,
        X: Tensor,
        H: HiddenState | None,
        hooks: TransformerHooks,
    ) -> tuple[Tensor, HiddenState]:
        # unembedding
        X, H, hooks = self.run_ln_f(X, H, hooks)
        X = self.unembed(X)
        X = hooks.logits(X)
        return X, H

    def sample(
        self,
        prompts: str | list[str] | list[int] | list[list[int]],
        num_tokens: int = 5,
        temperature: float = 1.0,
        top_p: float | None = None,
        hooks: TransformerHooks = TransformerHooks(),
    ) -> dict[str, Any]:
        """
        Sampling with the transformer!

        If top_p is set, then nucleus sampling is used. Otherwise, the sampling will be
        Categorical. If temperature=0, sampling is deterministic (and top_p is ignored).

        (Warning: when using torch.use_deterministic_algorithms(True), nucleus sampling will
        throw an error. It depends on torch.cumsum, which unfortunately has no deterministic
        implementation in torch.)

        Output is a dict {'tokens': list[list[int]], 'strings': list[str]}
        """
        prompts = [prompts] if isinstance(prompts, str) else prompts
        if isinstance(prompts[0], str):
            X: list[list[int]] = [self.enc.encode(prompt) for prompt in prompts]
        elif isinstance(prompts[0], int):
            X = [prompts]
        else:
            X = prompts
        X, pad = prep_input_and_pad(X, "left", self.device)
        H = None
        beta = 1 / max(temperature, 1e-10)
        out = {
            "tokens": [[] for _ in prompts],
            "strings": ["" for _ in prompts],
        }
        # sampling loop
        for _ in range(num_tokens):
            with torch.no_grad():
                # get logits
                Y, H = self.forward(X, H, pad, hooks=hooks)
                logits = Y[:, -1] * beta
            # sampling only works if logits are floats
            logits = logits.float()
            # perform sampling
            if temperature == 0:
                tokens = torch.argmax(logits, dim=-1)
            elif top_p is not None:
                tokens = nucleus_sample(logits, top_p)
            else:
                tokens = Categorical(logits=logits).sample()
            X, pad = tokens.unsqueeze(-1), None
            for batch_idx, token in enumerate(tokens):
                out["tokens"][batch_idx].append(token.item())
                out["strings"][batch_idx] += self.enc.decode([token.item()])
        return out

    @classmethod
    def load(
        cls,
        name_or_path: str,
        device: Device | None = None,
        dtype: torch.dtype | None = None,
        simplify: bool = False,
        simplify_kwargs: dict[str, Any] | None = None,
    ) -> "Transformer":
        if name_or_path.startswith("https://"):
            path = name_or_path
        else:
            path = f"https://openaipublic.blob.core.windows.net/neuron-explainer/subject-models/{name_or_path.replace('-', '/')}"
        xf = cls.from_checkpoint(
            path,
            device=device,
            dtype=dtype,
        )
        if simplify:
            if simplify_kwargs is None:
                simplify_kwargs = {}
            xf.simplify(**simplify_kwargs)
        return xf

    def save_checkpoint(
        self,
        path: str,
    ) -> None:
        self.cfg.save(osp.join(path, "config.json"))
        pieces_path = osp.join(path, "model_pieces")
        for k, v in self.state_dict().items():
            with CustomFileHandler(osp.join(pieces_path, f"{k}.pt"), "wb") as f:
                torch.save(v, f)

    def load_state_from_checkpoint(
        self, path: str, device: Device | None = None, dtype: torch.dtype | None = None
    ):
        pieces_path = osp.join(path, "model_pieces")
        piece_names = set(self.state_dict().keys())
        piece_files = [f"{k}.pt" for k in piece_names]
        if dtype is not None:
            assert isinstance(dtype, torch.dtype), "Must provide valid dtype."
        device = device or self.device
        with ThreadPoolExecutor(max_workers=50) as executor:
            k_to_future = {
                fname[: -len(".pt")]: executor.submit(
                    _load_piece, osp.join(pieces_path, fname), device, dtype
                )
                for fname in piece_files
            }
            d = {k: future.result() for k, future in k_to_future.items()}
        self.load_state_dict(d)

    @classmethod
    def from_checkpoint(
        cls, path: str, device: Device | None = None, dtype: torch.dtype | None = None
    ) -> "Transformer":
        if device is None:
            device = default_device()
        cfg = TransformerConfig.load(osp.join(path, "config.json"))
        xf = cls(cfg, device=device, dtype=dtype)
        xf.load_state_from_checkpoint(path, device=device, dtype=dtype)
        return xf
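
# Illustrative sketch: greedy sampling from a small, randomly initialized Transformer.
# The `_demo_sampling` helper, tiny config, and random weights are example choices; real
# use would load trained weights via Transformer.load or Transformer.from_checkpoint.
# Norm parameters are torch.empty-allocated, so they are set explicitly here; the sampled
# text is meaningless, but the shapes and API are representative.
def _demo_sampling() -> None:
    torch.manual_seed(0)
    cfg = TransformerConfig(d_model=32, n_heads=4, n_layers=2, ctx_window=64)
    model = Transformer(cfg)
    for m in model.modules():
        if isinstance(m, Norm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
    out = model.sample(["Hello", "The quick brown fox"], num_tokens=3, temperature=0)
    assert len(out["tokens"]) == 2
    assert all(len(t) == 3 for t in out["tokens"])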

def _load_piece(
    file_path: str, device: Device, dtype: torch.dtype | None
) -> torch.Tensor:
    disk_cache_path = osp.join(
        "/tmp/neuron-explainer-model-pieces-cache", file_path.replace("https://", "")
    )
    if not file_exists(disk_cache_path):
        copy_to_local_cache(file_path, disk_cache_path)
    with CustomFileHandler(disk_cache_path, "rb") as f:
        t = torch.load(f, map_location=device)
    if dtype is not None:
        t = t.to(dtype)
    return t
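
# Illustrative sketch: the logits for a prefix computed in one full forward pass should
# match those computed incrementally through the HiddenState / KeyValueCache path. The
# `_demo_kv_cache_consistency` helper and config sizes are example choices; Norm
# parameters are set explicitly because they are torch.empty-allocated.
def _demo_kv_cache_consistency() -> None:
    torch.manual_seed(0)
    cfg = TransformerConfig(d_model=32, n_heads=4, n_layers=2, ctx_window=16)
    model = Transformer(cfg)
    for m in model.modules():
        if isinstance(m, Norm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
    tokens = torch.randint(0, model.n_vocab, (1, 6))
    with torch.no_grad():
        full_logits, _ = model(tokens)            # all 6 positions at once
        _, H = model(tokens[:, :4])               # first 4 positions, populates the KV cache
        logits_b, H = model(tokens[:, 4:], H)     # remaining 2 positions, reusing the cache
    assert torch.allclose(full_logits[:, 4:], logits_b, atol=1e-4)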