# neuron_explainer/models/hooks.py

import inspect
from collections import OrderedDict
from copy import deepcopy

import torch


class Hooks:
    """A callable that applies a sequence of callables in order."""

    def __init__(self):
        self._hooks = []
        self.bound_kwargs = {}

    def __call__(self, x, *args, **kwargs):
        for hook in self._hooks:
            x = hook(x, *args, **kwargs, **self.bound_kwargs)
        return x

    def append(self, hook):
        self._hooks.append(hook)
        return self

    def bind(self, **kwargs):
        # Fix keyword arguments that will be forwarded to every hook on each call.
        for key, value in kwargs.items():
            if key in self.bound_kwargs:
                raise ValueError(f"Key {key} already bound")
            self.bound_kwargs[key] = value
        return self

    def unbind(self, keys: list):
        for key in keys:
            del self.bound_kwargs[key]
        return self

    def __repr__(self, indent=0, name=None):
        indent_str = " " * indent
        full_name = f"{name}" if name is not None else self.name
        if self.bound_kwargs:
            full_name += f" {self.bound_kwargs}"
        if self.is_empty():
            return f"{indent_str}{full_name}"

        def hook_repr(hook):
            # Recurse with extra indentation if the hook's __repr__ supports it.
            if "indent" in inspect.signature(hook.__class__.__repr__).parameters:
                return hook.__repr__(indent=indent + 4)
            else:
                return indent_str + " " * 4 + repr(hook)

        hooks_repr = "\n".join(hook_repr(hook) for hook in self._hooks)
        return f"{indent_str}{full_name}\n{hooks_repr}"

    @property
    def name(self):
        return self.__class__.__name__

    def is_empty(self):
        return len(self._hooks) == 0
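

# Hedged usage sketch (added for illustration, not part of the original file):
# hooks run in insertion order, and `bind` pre-fills keyword arguments that
# are forwarded to every hook on each call.
def _example_hooks_usage():
    hooks = Hooks()
    hooks.append(lambda x, **kwargs: x * 2)
    hooks.append(lambda x, **kwargs: x + kwargs["offset"])
    hooks.bind(offset=1)
    return hooks(3)  # (3 * 2) + 1 == 7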


def grad_hook_wrapper(grad_hook):
    """Turn a gradient hook into a forward-pass hook.

    The returned hook is the identity on the forward pass and applies
    `grad_hook` to the gradient on the backward pass.
    """

    def fwd_hook(act, *args, **kwargs):
        class _IdentityWithGradHook(torch.autograd.Function):
            @staticmethod
            def forward(ctx, tensor):
                return tensor

            @staticmethod
            def backward(ctx, grad_output):
                return grad_hook(grad_output, *args, **kwargs)

        return _IdentityWithGradHook.apply(act)

    return fwd_hook
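

# Hedged sketch (illustration only): `grad_hook_wrapper` lets a gradient hook
# be registered anywhere a forward hook is expected. Here the wrapped hook
# halves the gradient while leaving the forward value untouched.
def _example_grad_hook_wrapper_usage():
    halve_grad = grad_hook_wrapper(lambda grad, **kwargs: grad * 0.5)
    x = torch.ones(3, requires_grad=True)
    halve_grad(x).sum().backward()
    return x.grad  # tensor([0.5, 0.5, 0.5])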


class HookCollection:
    def __init__(self):
        self.all_hooks = OrderedDict()

    def bind(self, **kwargs):
        for hook in self.all_hooks.values():
            hook.bind(**kwargs)
        return self

    def unbind(self, keys):
        for hook in self.all_hooks.values():
            hook.unbind(keys)
        return self

    def append_all(self, hook):
        for hooks in self.all_hooks.values():
            try:
                # Recurse into nested collections; plain Hooks lack append_all.
                hooks.append_all(hook)
            except AttributeError:
                hooks.append(hook)
        return self

    def append_to_path(self, hook_location_name, hook):
        """
        Adds a hook to a location in a nested hook collection with a dot-separated name.
        e.g. `self.append_to_path("resid.torso.post_mlp.fwd", hook)` adds `hook` to
        `self.all_hooks["resid"].all_hooks["torso"].all_hooks["post_mlp"].all_hooks["fwd"]`
        """
        hook_location_parts = hook_location_name.split(".", 1)  # split at first dot, if present
        top_level_location = hook_location_parts[0]
        assert top_level_location in self.all_hooks
        if len(hook_location_parts) == 1:  # no dot in path
            self.all_hooks[top_level_location].append(hook)
        else:  # at least one dot in path -> split outputs two parts
            sub_location = hook_location_parts[1]
            self.all_hooks[top_level_location].append_to_path(sub_location, hook)
        return self

    def __deepcopy__(self, memo):
        # The default deepcopy recurses through the custom __getattr__, which
        # assumes `all_hooks` already exists, so copy `all_hooks` explicitly.
        new = self.__class__()
        new.all_hooks = deepcopy(self.all_hooks)
        return new

    def add_subhooks(self, name, subhooks):
        self.all_hooks[name] = subhooks
        return self

    def __getattr__(self, name):
        if name in self.all_hooks:
            return self.all_hooks[name]
        else:
            raise AttributeError(f"HookCollection has no attribute {name}")

    def __repr__(self, indent=0, name=None):
        indent_str = " " * indent
        full_name = f"{name}" if name is not None else self.__class__.__name__
        prefix = f"{name}." if name is not None else ""
        hooks_repr = "\n".join(
            hook.__repr__(indent + 4, f"{prefix}{hook_name}")
            for hook_name, hook in self.all_hooks.items()
        )
        return f"{indent_str}{full_name}\n{hooks_repr}"

    def is_empty(self):
        return all(hook.is_empty() for hook in self.all_hooks.values())
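

# Hedged sketch (illustration only): a hand-built two-level collection.
# `add_subhooks` registers a named child, `__getattr__` exposes it as an
# attribute, and `append_to_path` walks dotted names to a leaf.
def _example_hook_collection_usage():
    coll = HookCollection().add_subhooks(
        "inner", HookCollection().add_subhooks("fwd", Hooks())
    )
    coll.append_to_path("inner.fwd", lambda x, **kwargs: x + 1)
    return coll.inner.fwd(1)  # == 2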


class FwdBwdHooks(HookCollection):
    def __init__(self):
        super().__init__()
        # By default, all grad hooks are applied after all forward hooks. This way,
        # the gradients are given for the final "hooked" output of the forward pass.
        # If you want gradients for an intermediate output, you should simply
        # append_fwd(grad_hook_wrapper(hook)) at the appropriate time.
        self.add_subhooks("fwd", Hooks())
        self.add_subhooks("bwd", WrapperHooks(wrapper=grad_hook_wrapper))
        self.add_subhooks("fwd2", Hooks())

    def append_fwd(self, fwd_hook):
        self.fwd.append(fwd_hook)
        return self

    def append_bwd(self, bwd_hook):
        self.bwd.append(bwd_hook)
        return self

    def append_fwd2(self, fwd2_hook):
        self.fwd2.append(fwd2_hook)
        return self

    def __call__(self, x, *args, **kwargs):
        # Apply hooks in order: fwd, then bwd (identity-with-grad-hook), then fwd2.
        x = self.fwd(x, *args, **kwargs)
        x = self.bwd(x, *args, **kwargs)
        x = self.fwd2(x, *args, **kwargs)
        return x
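

# Hedged sketch (illustration only): `fwd` hooks edit the activation, `bwd`
# hooks edit the gradient that flows back through this point, and `fwd2`
# hooks see the activation after the gradient hook has been attached.
def _example_fwd_bwd_hooks_usage():
    hooks = FwdBwdHooks()
    hooks.append_fwd(lambda x, **kwargs: x * 2)
    hooks.append_bwd(lambda grad, **kwargs: grad * 0.0)
    x = torch.ones(2, requires_grad=True)
    hooks(x).sum().backward()
    return x.grad  # zeros: the bwd hook zeroed the gradient at this point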


class MLPHooks(HookCollection):
    def __init__(self):
        super().__init__()
        self.add_subhooks("pre_act", FwdBwdHooks())
        self.add_subhooks("post_act", FwdBwdHooks())


class NormalizationHooks(HookCollection):
    def __init__(self):
        super().__init__()
        self.add_subhooks("post_mean_subtraction", FwdBwdHooks())
        self.add_subhooks("scale", FwdBwdHooks())
        self.add_subhooks("post_scale", FwdBwdHooks())


class AttentionHooks(HookCollection):
    def __init__(self):
        super().__init__()
        self.add_subhooks("q", FwdBwdHooks())
        self.add_subhooks("k", FwdBwdHooks())
        self.add_subhooks("v", FwdBwdHooks())
        self.add_subhooks("qk_logits", FwdBwdHooks())
        self.add_subhooks("qk_softmax_denominator", FwdBwdHooks())
        self.add_subhooks("qk_probs", FwdBwdHooks())
        self.add_subhooks("v_out", FwdBwdHooks())  # pre-final projection


class ResidualStreamTorsoHooks(HookCollection):
    def __init__(self):
        super().__init__()
        self.add_subhooks("post_ln_attn", FwdBwdHooks())
        self.add_subhooks("ln_attn", NormalizationHooks())
        self.add_subhooks("delta_attn", FwdBwdHooks())
        self.add_subhooks("post_attn", FwdBwdHooks())
        self.add_subhooks("ln_mlp", NormalizationHooks())
        self.add_subhooks("post_ln_mlp", FwdBwdHooks())
        self.add_subhooks("delta_mlp", FwdBwdHooks())
        self.add_subhooks("post_mlp", FwdBwdHooks())


class ResidualStreamHooks(HookCollection):
    def __init__(self):
        super().__init__()
        self.add_subhooks("post_emb", FwdBwdHooks())
        self.add_subhooks("torso", ResidualStreamTorsoHooks())
        self.add_subhooks("ln_f", NormalizationHooks())
        self.add_subhooks("post_ln_f", FwdBwdHooks())


class TransformerHooks(HookCollection):
    def __init__(self):
        super().__init__()
        self.add_subhooks("mlp", MLPHooks())
        self.add_subhooks("attn", AttentionHooks())
        self.add_subhooks("resid", ResidualStreamHooks())
        self.add_subhooks("logits", FwdBwdHooks())


class WrapperHooks(Hooks):
    """Hooks that pass each appended hook through `wrapper` first."""

    def __init__(self, wrapper):
        self.wrapper = wrapper
        super().__init__()

    def append(self, fn):
        self._hooks.append(self.wrapper(fn))
        return self


class ConditionalHooks(Hooks):
    """Hooks that only run when `condition(x, *args, **kwargs)` is truthy."""

    def __init__(self, condition):
        self.condition = condition
        super().__init__()

    def __call__(self, x, *args, **kwargs):
        if self.condition(x, *args, **kwargs):
            x = super().__call__(x, *args, **kwargs)
        return x


class AtLayers(ConditionalHooks):
    def __init__(self, at_layers: int | list[int]):
        if isinstance(at_layers, int):
            at_layers = [at_layers]
        self.at_layers = at_layers

        def at_layers_condition(x, *, layer, **kwargs):
            return layer in at_layers

        super().__init__(condition=at_layers_condition)

    @property
    def name(self):
        return f"{self.__class__.__name__}({self.at_layers})"


class AutoencoderHooks(HookCollection):
    """
    Hooks into the hidden dimension of an autoencoder.
    """

    def __init__(self, encode, decode, add_error=False):
        super().__init__()
        # hooks
        self.add_subhooks("latents", FwdBwdHooks())
        self.add_subhooks("reconstruction", FwdBwdHooks())
        self.add_subhooks("error", FwdBwdHooks())
        # autoencoder functions
        self.encode = encode
        self.decode = decode
        # If add_error is True, add the reconstruction error back to the output.
        self.add_error = add_error

    def __call__(self, x, *args, **kwargs):
        latents = self.encode(x)
        if self.add_error:
            # Clone the latents twice:
            # - the first clone is passed through `self.latents` and `self.reconstruction`
            # - the second clone is passed through `self.error`
            latents_to_hook = latents.clone()
            latents_to_error = latents.clone()
        else:
            latents_to_hook = latents
        latents_to_hook = self.latents(latents_to_hook, *args, **kwargs)  # call hooks
        reconstruction = self.decode(latents_to_hook)
        reconstruction = self.reconstruction(reconstruction, *args, **kwargs)  # call hooks
        if self.add_error:
            # The error is computed from the unhooked latents, so hook edits
            # do not leak into the error term.
            error = x - self.decode(latents_to_error)
            error = self.error(error, *args, **kwargs)  # call hooks
            return reconstruction + error
        else:
            error = x - reconstruction
            error = self.error(error, *args, **kwargs)  # call hooks
            return reconstruction

    def __deepcopy__(self, memo):
        # See HookCollection.__deepcopy__: the custom __getattr__ breaks the
        # default deepcopy, so rebuild the instance explicitly.
        new = self.__class__(self.encode, self.decode, self.add_error)
        new.all_hooks = deepcopy(self.all_hooks)
        return new
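

# Hedged sketch (illustration only): splicing an identity "autoencoder" into
# a forward pass. With add_error=True the output is reconstruction + error,
# so it equals the input exactly when no hook edits anything; here a latents
# hook clamps negative latents before decoding.
def _example_autoencoder_hooks_usage():
    hooks = AutoencoderHooks(encode=lambda x: x, decode=lambda z: z, add_error=True)
    hooks.latents.append_fwd(lambda z, **kwargs: z.clamp(min=0))
    return hooks(torch.tensor([-1.0, 2.0]))  # tensor([0., 2.]): error term is zero here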