# neuron_explainer/models/transformer.py
import json
import os.path as osp
import pickle
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import cache
from typing import Any, Self, Union
import numpy as np
import tiktoken
import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions.categorical import Categorical
from torch.utils.checkpoint import checkpoint
from neuron_explainer.file_utils import CustomFileHandler, copy_to_local_cache, file_exists
from neuron_explainer.models.hooks import (
AttentionHooks,
MLPHooks,
NormalizationHooks,
TransformerHooks,
)
# for static analysis
Device = Union[torch.device, str]
# NOTE: some code from this file related to attention, MLP, and layernorm operations is copy-pasted into
# neuron_explainer/activations/derived_scalars/reconstituted.py; if those operations change here, they
# should be changed correspondingly in that file.
class SerializableDataclass:
def to_dict(self) -> dict:
return asdict(self)
@classmethod
    def from_dict(cls, d: dict) -> Self:
return cls(**d)
def save(self, path: str) -> None:
if path.endswith((".pkl", ".pickle")):
with CustomFileHandler(path, "wb") as f:
pickle.dump(self.to_dict(), f)
elif path.endswith(".json"):
with CustomFileHandler(path, "w") as f:
json.dump(self.to_dict(), f)
else:
raise ValueError(f"Unknown file extension for {path}")
@classmethod
def load(cls, path: str) -> Self:
if path.endswith((".pkl", ".pickle")):
with CustomFileHandler(path, "rb") as f:
return cls.from_dict(pickle.load(f))
elif path.endswith(".json"):
with CustomFileHandler(path, "r") as f:
return cls.from_dict(json.load(f))
else:
raise ValueError(f"Unknown file extension for {path}")
@dataclass
class TransformerConfig(SerializableDataclass):
enc: str = "gpt2"
ctx_window: int = 1024
d_model: int = 256
n_layers: int = 2
# attn
m_attn: float = 1
n_heads: int = 8
# mlp
m_mlp: float = 4
@property
def d_ff(self) -> int:
return int(self.d_model * self.m_mlp)
@property
def d_attn_qk(self) -> int:
return int(self.d_model * self.m_attn)
@property
def d_attn_v(self) -> int:
return int(self.d_model * self.m_attn)
@property
def d_head_qk(self) -> int:
return safe_div(self.d_attn_qk, self.n_heads)
@property
def d_head_v(self) -> int:
return safe_div(self.d_attn_v, self.n_heads)
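# Illustrative sketch of the derived dimensions and the SerializableDataclass round trip
# (the values and path below are hypothetical, not part of this module):
#   cfg = TransformerConfig(d_model=256, n_layers=2, n_heads=8, m_mlp=4)
#   cfg.d_ff        # -> 1024 (d_model * m_mlp)
#   cfg.d_head_qk   # -> 32   (d_attn_qk / n_heads)
#   cfg.save("/tmp/config.json")
#   assert TransformerConfig.load("/tmp/config.json") == cfg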
def default_device() -> torch.device:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def safe_div(numerator: int, denominator: int) -> int:
assert numerator % denominator == 0
return numerator // denominator
# ====================
# Attention utilities
# ====================
@cache
def causal_attn_mask(size: int, device: Device = "cpu") -> Tensor:
return torch.tril(torch.ones(size, size)).bool().to(device)
def split_heads(Z: Tensor, n_heads: int) -> Tensor:
batch, seq, d_attn = Z.shape
return Z.reshape(batch, seq, n_heads, d_attn // n_heads)
def merge_heads(Z: Tensor) -> Tensor:
batch, seq, n_heads, d_head = Z.shape
return Z.reshape(batch, seq, n_heads * d_head)
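# Shape sketch for the helpers above (sizes are hypothetical):
#   Z = torch.randn(2, 5, 256)            # batch=2, seq=5, d_attn=256
#   split_heads(Z, n_heads=8).shape       # -> (2, 5, 8, 32)
#   merge_heads(split_heads(Z, 8)).shape  # -> (2, 5, 256)
#   causal_attn_mask(3)                   # -> 3x3 lower-triangular boolean mask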
# ===================================
# MLP utilities
# ===================================
def gelu(x: Tensor) -> Tensor:
    # tanh approximation of GELU (as used in GPT-2)
    return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
    # alternative sigmoid approximation: return x * torch.sigmoid(1.702 * x)
# ========================================
# Sampling, padding and related utilities
# ========================================
def prep_input_and_right_pad_for_forward_pass(
X: list[list[int]], device: Device = "cpu"
) -> tuple[Tensor, Tensor]:
# Helper function. The two tensors returned by this function are suitable to be passed to
# Transformer.forward.
return prep_input_and_pad(X, "right", device)
def prep_input_and_pad(
X: list[list[int]], pad_side: str, device: Device = "cpu"
) -> tuple[Tensor, Tensor]:
    # X is a list of tokenized prompts; prompts may have unequal lengths. This function pads X on
    # the given side ("left" or "right") by putting "-1" in all the slots where a prompt is shorter
    # than the longest prompt. It then converts X into a tensor of int tokens, builds the pad
    # tensor by looking for the "-1"s, and finally fills the "-1"s in X with "0"s so the embedding
    # layer doesn't get upset.
max_len = max([len(prompt) for prompt in X])
def pad(x):
padding = [-1] * (max_len - len(x))
if pad_side == "left":
return padding + x
elif pad_side == "right":
return x + padding
else:
raise ValueError(f"pad_side must be 'left' or 'right', not {pad_side}")
X_tensor = torch.LongTensor([pad(prompt) for prompt in X]).to(device)
pad = X_tensor == -1
X_tensor = torch.where(X_tensor == -1, 0, X_tensor)
return X_tensor, pad
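# Padding sketch (token ids are hypothetical):
#   X_tensor, pad = prep_input_and_pad([[5, 6, 7], [8, 9]], "right")
#   X_tensor  # tensor([[5, 6, 7], [8, 9, 0]])  -- the "-1" slot is zero-filled
#   pad       # tensor([[False, False, False], [False, False, True]])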
def prep_pos_from_pad_and_prev_lens(pad: Tensor, prev_lens: Tensor) -> Tensor:
# pad has shape b x s, prev_lens has shape b x 1.
# For position embedding, we need a tensor of shape (b x s) whose
# entries are the positions of X in the sequence. When sampling with
# prompts of unequal length, X is left padded with pad tokens. The
# position tensor needs to take that into account.
pos = torch.logical_not(pad).long().cumsum(dim=-1) - 1
pos = torch.where(pos == -1, 0, pos)
return pos + prev_lens
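# Position sketch for a left-padded batch (values are hypothetical):
#   pad = torch.tensor([[True, False, False]])   # one pad token, then two real tokens
#   prep_pos_from_pad_and_prev_lens(pad, torch.zeros(1, 1, dtype=torch.long))
#   # -> tensor([[0, 0, 1]]): the real tokens get positions 0 and 1; the pad slot is clamped to 0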
def nucleus_sample(logits: Tensor, top_p: float) -> Tensor:
# top_p in [0,1] is the total probability mass of top outputs.
# based on https://nn.labml.ai/sampling/nucleus.html
# input shape: [..., n_vocab] -> output shape: [...]
sorted_logits, idxs = torch.sort(logits, dim=-1, descending=True)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
cum_probs = torch.cumsum(sorted_probs, dim=-1)
# logic to ensure there is always at least one token with nonzero
# probability when selecting nucleus.
p0 = cum_probs[..., 0]
top_p = torch.where(p0 > top_p, p0, top_p)[..., None]
# sampling
do_not_sample = cum_probs > top_p
sorted_logits = sorted_logits.masked_fill(do_not_sample, float("-inf"))
dist = Categorical(logits=sorted_logits)
samples = dist.sample()
tokens = idxs.gather(-1, samples.unsqueeze(-1)).squeeze(-1)
return tokens
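# Usage sketch (shapes are hypothetical):
#   logits = torch.randn(4, 50257)             # [batch, n_vocab]
#   tokens = nucleus_sample(logits, top_p=0.95)
#   tokens.shape                               # -> torch.Size([4])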
# ===============
# Layer Norm
# ===============
class Norm(nn.Module):
"""LayerNorm reimplementation with hooks."""
def __init__(
self,
size: int,
eps: float = 1e-5,
device: Device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
kwargs = {"device": device, "dtype": dtype}
self.size = size
self.weight = nn.Parameter(torch.empty(size, **kwargs)) # type: ignore[arg-type]
self.bias = nn.Parameter(torch.empty(size, **kwargs)) # type: ignore[arg-type]
self.eps = eps
    def forward(self, x: Tensor, hooks: NormalizationHooks | None = None) -> Tensor:
if hooks is None:
hooks = NormalizationHooks()
# always do norm in fp32
orig_dtype = x.dtype
x = x.float()
        x = x - x.mean(dim=-1, keepdim=True)  # [batch, pos, length]
x = hooks.post_mean_subtraction(x)
scale = torch.sqrt((x**2).mean(dim=-1, keepdim=True) + self.eps)
scale = hooks.scale(scale)
x = x / scale
x = hooks.post_scale(x)
ret = x * self.weight + self.bias
return ret.to(orig_dtype)
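# Usage sketch (weights are uninitialized at construction; a real model loads them from a
# checkpoint, so the fill below is purely illustrative):
#   ln = Norm(8)
#   ln.weight.data.fill_(1.0); ln.bias.data.zero_()
#   y = ln(torch.randn(2, 4, 8))   # per-position mean ~0, RMS ~1 along the last dim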
def apply_layernorm_foldin(ln: Norm, linears: list[nn.Linear]) -> None:
    # Folds a layernorm's weight/bias into the following linear layer(s).
# ln(x) = W_ln * (x - x.mean())/(x.std()) + b_ln
# linear(ln(x)) = W_linear * (W_ln * (x - x.mean())/(x.std()) + b_ln) + b_linear
W_ln = ln.weight.float()
b_ln = ln.bias.float()
for linear in linears:
W_linear = linear.weight.float()
b_linear = linear.bias.float()
W_composed = W_linear * W_ln[None, :]
        b_composed = b_linear + W_linear @ b_ln
# should only copy after new weights are calculated
linear.weight.data.copy_(W_composed)
linear.bias.data.copy_(b_composed)
ln.weight.data[:] = 1
ln.bias.data[:] = 0
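# Equivalence sketch for the fold-in above (random weights, illustrative only):
#   ln, lin = Norm(8), nn.Linear(8, 8)
#   ln.weight.data.normal_(); ln.bias.data.normal_()
#   x = torch.randn(2, 3, 8)
#   before = lin(ln(x))
#   apply_layernorm_foldin(ln, [lin])
#   after = lin(ln(x))             # ln is now identity-affine; lin absorbed its weight/bias
#   # before and after agree up to floating-point error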
# ===========================================
# Attention layers and associated components
# ===========================================
@dataclass
class KeyValueCache:
"""KV cache to save on compute"""
K_cache: Tensor | None = None # b x s_old x d
V_cache: Tensor | None = None # b x s_old x d
pad_cache: Tensor | None = None # b x s_old
    def update(
        self, K: Tensor, V: Tensor, pad: Tensor | None
    ) -> tuple[Tensor, Tensor, Tensor | None]:
# K, V have shape: b x (s_new - s_old) x d
# pad has shape: b x (s_new - s_old)
new = self.K_cache is None
self.K_cache = K if new else torch.cat([self.K_cache, K], dim=1)
self.V_cache = V if new else torch.cat([self.V_cache, V], dim=1)
self.pad_cache = pad if new else torch.cat([self.pad_cache, pad], dim=1)
return self.K_cache, self.V_cache, self.pad_cache
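# Incremental-update sketch (shapes are hypothetical):
#   cache = KeyValueCache()
#   cache.update(torch.randn(2, 5, 64), torch.randn(2, 5, 64), torch.zeros(2, 5, dtype=torch.bool))
#   K, V, pad = cache.update(torch.randn(2, 1, 64), torch.randn(2, 1, 64), torch.zeros(2, 1, dtype=torch.bool))
#   K.shape   # -> (2, 6, 64): the new key is concatenated along the sequence dim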
class MultiHeadedDotProductSelfAttention(nn.Module):
"""A configurable multi-headed dot product attention layer."""
def __init__(
self,
cfg: TransformerConfig,
layer_idx: int,
device: Device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.n_heads = cfg.n_heads
# make layers
kwargs = {"device": device, "dtype": dtype}
self.q_proj = nn.Linear(cfg.d_model, cfg.d_attn_qk, **kwargs)
self.k_proj = nn.Linear(cfg.d_model, cfg.d_attn_qk, **kwargs)
self.v_proj = nn.Linear(cfg.d_model, cfg.d_attn_v, **kwargs)
self.out_proj = nn.Linear(cfg.d_attn_v, cfg.d_model, **kwargs)
self.qk_scale = 1 / np.sqrt(np.sqrt(cfg.d_head_qk))
self.cfg = cfg
def forward(
self,
X: Tensor,
kv_cache: KeyValueCache | None = None,
pad: Tensor | None = None,
hooks: AttentionHooks = AttentionHooks(),
) -> tuple[Tensor, KeyValueCache]:
Q = self.q_proj(X)
K = self.k_proj(X)
V = self.v_proj(X)
# update KV cache
if kv_cache is None:
kv_cache = KeyValueCache()
K, V, pad = kv_cache.update(K, V, pad)
# split apart heads, rescale QK
Q = split_heads(Q, self.n_heads) * self.qk_scale
Q = hooks.q(Q) # bshd
K = split_heads(K, self.n_heads) * self.qk_scale
K = hooks.k(K) # bshd
V = split_heads(V, self.n_heads)
V = hooks.v(V) # bshd
# useful for calculations below
n_queries, n_keys = Q.shape[1], K.shape[1]
# softmax multi-headed dot product attention
pre_softmax = torch.einsum("bqhd,bkhd -> bhqk", Q, K)
# apply causal attention mask
M = causal_attn_mask(n_keys, device=X.device)
M = M[None, None, -n_queries:] # make M broadcastable to batch, head
pre_softmax = pre_softmax.masked_fill(torch.logical_not(M), float("-inf"))
# apply pad mask
if pad is not None and torch.any(pad):
# we only mask out pad tokens for non-pad query tokens
# (because masking all pad tokens => empty rows => NaNs later)
pad_mask = torch.bitwise_xor(pad[:, None, :], pad[:, :, None])
# make pad broadcastable on head dim, and slice for current queries only
pad_mask = pad_mask[:, None, -n_queries:]
# apply pad mask
pre_softmax = pre_softmax.masked_fill(pad_mask, float("-inf"))
pre_softmax = torch.einsum("bhqk->bqkh", pre_softmax)
pre_softmax = hooks.qk_logits(pre_softmax)
pre_softmax = pre_softmax.float() # for numerical stability
if hooks.qk_softmax_denominator.is_empty():
attn = torch.softmax(pre_softmax, dim=-2)
else:
# factor out softmax in order to hook
pre_softmax_max = torch.max(pre_softmax, -2, keepdim=True)[0].detach()
numerator = torch.exp(pre_softmax - pre_softmax_max)
denominator = numerator.sum(dim=-2, keepdim=True)
denominator = hooks.qk_softmax_denominator(denominator)
attn = numerator / denominator
attn = attn.to(Q.dtype)
attn = hooks.qk_probs(attn)
out = torch.einsum("bqkh,bkhd->bqhd", attn, V)
out = hooks.v_out(out)
out = merge_heads(out) # concatenate results from all heads
# final output projection
return self.out_proj(out), kv_cache
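# Usage sketch (default config, freshly initialized weights; illustrative only):
#   cfg = TransformerConfig()
#   attn = MultiHeadedDotProductSelfAttention(cfg, layer_idx=0)
#   out, kv_cache = attn(torch.randn(2, 5, cfg.d_model))
#   out.shape               # -> (2, 5, cfg.d_model)
#   kv_cache.K_cache.shape  # -> (2, 5, cfg.d_attn_qk)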
# =====================================
# MLP layers and associated components
# =====================================
class MLP(nn.Module):
"""An MLP for a transformer is a simple two-layer network."""
def __init__(
self, cfg: TransformerConfig, device: Device | None = None, dtype: torch.dtype | None = None
):
super().__init__()
kwargs = {"device": device, "dtype": dtype}
self.in_layer = nn.Linear(cfg.d_model, cfg.d_ff, **kwargs)
self.out_layer = nn.Linear(cfg.d_ff, cfg.d_model, **kwargs)
self.act = gelu
def forward(self, X: Tensor, hooks: MLPHooks = MLPHooks()) -> Tensor:
pre = self.in_layer(X)
pre = hooks.pre_act(pre)
a = self.act(pre)
a = hooks.post_act(a, out_layer=self.out_layer)
out = self.out_layer(a)
return out
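# Usage sketch (illustrative only):
#   mlp = MLP(TransformerConfig())
#   mlp(torch.randn(2, 5, 256)).shape   # -> (2, 5, 256); the hidden width is d_ff = 1024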
# =============
# Transformers
# =============
class TransformerLayer(nn.Module):
def __init__(
self,
cfg: TransformerConfig,
layer_idx: int,
device: Device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
kwargs = {"device": device, "dtype": dtype}
self.cfg = cfg
self.attn = MultiHeadedDotProductSelfAttention(cfg, layer_idx, **kwargs)
self.mlp = MLP(cfg, **kwargs)
self.ln_1 = Norm(cfg.d_model, **kwargs)
self.ln_2 = Norm(cfg.d_model, **kwargs)
self.layer_idx = layer_idx
def simplify(self) -> None:
ln_1_linears: list[Any] = [
self.attn.q_proj,
self.attn.k_proj,
self.attn.v_proj,
]
apply_layernorm_foldin(self.ln_1, ln_1_linears)
ln_2_linears: list[Any] = [self.mlp.in_layer]
apply_layernorm_foldin(self.ln_2, ln_2_linears)
    def attn_block(
        self, X: Tensor, kv_cache: KeyValueCache | None, pad: Tensor | None, hooks: TransformerHooks
    ) -> tuple[Tensor, KeyValueCache]:
ln_X = self.ln_1(X, hooks.resid.torso.ln_attn)
ln_X = hooks.resid.torso.post_ln_attn(ln_X)
attn_delta, kv_cache = self.attn(ln_X, kv_cache, pad, hooks.attn)
attn_delta = hooks.resid.torso.delta_attn(attn_delta)
return attn_delta, kv_cache
def mlp_block(self, X: Tensor, hooks: TransformerHooks) -> Tensor:
ln_X = self.ln_2(X, hooks.resid.torso.ln_mlp)
ln_X = hooks.resid.torso.post_ln_mlp(ln_X)
mlp_delta = self.mlp(ln_X, hooks.mlp)
mlp_delta = hooks.resid.torso.delta_mlp(mlp_delta)
return mlp_delta
def forward(
self,
X: Tensor,
kv_cache: KeyValueCache | None = None,
pad: Tensor | None = None,
hooks: TransformerHooks = TransformerHooks(),
) -> tuple[Tensor, KeyValueCache]:
attn_delta, kv_cache = self.attn_block(X, kv_cache, pad, hooks)
X = X + attn_delta
X = hooks.resid.torso.post_attn(X)
mlp_delta = self.mlp_block(X, hooks)
X = X + mlp_delta
X = hooks.resid.torso.post_mlp(X)
return X, kv_cache
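# Usage sketch (illustrative only): a layer maps the residual stream to itself, adding an
# attention delta and then an MLP delta.
#   layer = TransformerLayer(TransformerConfig(), layer_idx=0)
#   X, kv_cache = layer(torch.randn(2, 5, 256))
#   X.shape   # -> (2, 5, 256)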
class HiddenState:
"""A hidden state for a transformer. Tracks prompt lengths and KV caches."""
def __init__(self, n_layers: int):
        self.prev_lens: Tensor | int = 0
        self.kv_caches: list[KeyValueCache | None] = [None for _ in range(n_layers)]
    def set_prev_lens(self, prev_lens: Tensor) -> None:
        self.prev_lens = prev_lens
def __getitem__(self, idx: int):
return self.kv_caches[idx]
def __setitem__(self, idx: int, value: KeyValueCache | None):
self.kv_caches[idx] = value
class Transformer(nn.Module):
def __init__(
self,
cfg: TransformerConfig,
# recomputing is optional, and it trades off compute for memory.
recompute: bool = False,
device: Device | None = None,
dtype: torch.dtype | None = None,
):
super().__init__()
self.cfg = cfg
self.enc = tiktoken.get_encoding(self.cfg.enc)
self.n_vocab = self.enc.n_vocab
self.recompute = recompute
self.dtype = dtype
# build network
kwargs = {"device": device, "dtype": dtype}
self.tok_embed = nn.Embedding(self.n_vocab, cfg.d_model, **kwargs)
self.pos_embed = nn.Embedding(cfg.ctx_window, cfg.d_model, **kwargs)
self.xf_layers = nn.ModuleList(
[TransformerLayer(cfg, idx, **kwargs) for idx in range(cfg.n_layers)]
)
self.final_ln = Norm(cfg.d_model, **kwargs)
self.unembed = nn.Linear(cfg.d_model, self.n_vocab, bias=False, **kwargs)
    def simplify(self) -> None:
for xf_layer in self.xf_layers:
xf_layer.simplify()
# NOTE: we can't fold layer norm into unembedding layer
# because it has no bias
# apply_layernorm_foldin(self.final_ln, [self.unembed])
@property
def device(self) -> Device:
return next(self.parameters()).device
def set_recompute(self, recompute: bool) -> None:
self.recompute = recompute
def forward(
self,
tokens: Tensor,
H: HiddenState | None = None,
pad: Tensor | None = None,
hooks: TransformerHooks = TransformerHooks(),
) -> tuple[Tensor, HiddenState]:
"""
Forward pass through the transformer!
        During evaluation or the first forward pass in sampling:
            tokens is expected to be a [batch_size x sequence_length]-shaped LongTensor of
            encoded prompts.
            H is expected to be None.
            pad is a [batch_size x sequence_length]-shaped boolean Tensor. "1"s mean "ignore this
            token". This parameter must be set if not all encoded prompts in tokens have the same
            length. Note that activations observed by hooks will include padded values.
        During sampling after the first forward pass:
            tokens is expected to be the new part of the sequences (e.g. the most recently sampled
            tokens).
            H is expected to have KV caches of all keys and values for prior tokens.
            pad is expected to be None (new tokens are not pad tokens).
        Returns a tuple containing the resulting logits tensor and a new hidden state holding the
        per-layer KV caches.
"""
X, H, pad, hooks = self.run_embed(tokens, H, pad, hooks)
X, H, pad, hooks = self.run_torso(X, H, pad, hooks)
return self.run_unembed(X, H, hooks)
def run_embed(
self,
tokens: Tensor,
H: HiddenState | None = None,
pad: Tensor | None = None,
hooks: TransformerHooks = TransformerHooks(),
) -> tuple[Tensor, HiddenState, Tensor | None, TransformerHooks]:
        assert tokens.dtype == torch.long, "tokens must be a LongTensor of token ids."
if H is None:
H = HiddenState(self.cfg.n_layers)
if pad is None:
pad = torch.zeros_like(tokens, dtype=torch.bool)
# embedding
X = self.tok_embed(tokens)
# position encoding logic to support sampling with prompts of unequal length.
pos = prep_pos_from_pad_and_prev_lens(pad, H.prev_lens)
seq_lens = (pos[:, -1] + 1).unsqueeze(-1)
assert all(
seq_lens <= self.cfg.ctx_window
), f"sequences must fit in the context window {self.cfg.ctx_window}."
H.set_prev_lens(seq_lens)
X = X + self.pos_embed(pos)
X = hooks.resid.post_emb(X)
return X, H, pad, hooks
def run_torso(
self,
X: Tensor,
H: HiddenState | None,
pad: Tensor | None,
hooks: TransformerHooks,
) -> tuple[Tensor, HiddenState, Tensor | None, TransformerHooks]:
# transformer torso
for i, xf_layer in enumerate(self.xf_layers):
hooks_layer_i = deepcopy(hooks).bind(layer=i)
if self.recompute:
X, H[i] = checkpoint(xf_layer, X, H[i], pad, hooks_layer_i)
else:
X, H[i] = xf_layer(X, H[i], pad, hooks_layer_i)
return X, H, pad, hooks
def run_ln_f(
self,
X: Tensor,
H: HiddenState | None,
hooks: TransformerHooks,
) -> tuple[Tensor, HiddenState, TransformerHooks]:
X = self.final_ln(X, hooks.resid.ln_f)
X = hooks.resid.post_ln_f(X)
return X, H, hooks
def run_unembed(
self,
X: Tensor,
H: HiddenState | None,
hooks: TransformerHooks,
) -> tuple[Tensor, HiddenState]:
# unembedding
X, H, hooks = self.run_ln_f(X, H, hooks)
X = self.unembed(X)
X = hooks.logits(X)
return X, H
def sample(
self,
prompts: str | list[str] | list[int] | list[list[int]],
num_tokens: int = 5,
temperature: float = 1.0,
top_p: float | None = None,
hooks: TransformerHooks = TransformerHooks(),
) -> dict[str, Any]:
"""
Sampling with the transformer!
If top_p is set, then nucleus sampling is used.
Otherwise, the sampling will be Categorical.
If temperature=0, sampling is deterministic (and top_p is ignored).
(Warning: when using torch.use_deterministic_algorithms(True),
nucleus sampling will throw an error. It depends on torch.cumsum,
which unfortunately has no deterministic implementation in torch.)
Output is a dict {'tokens': list[list[int]], 'strings': list[str]}
"""
prompts = [prompts] if isinstance(prompts, str) else prompts
if isinstance(prompts[0], str):
X: list[list[int]] = [self.enc.encode(prompt) for prompt in prompts]
elif isinstance(prompts[0], int):
X = [prompts]
else:
X = prompts
X, pad = prep_input_and_pad(X, "left", self.device)
H = None
beta = 1 / max(temperature, 1e-10)
out = {
"tokens": [[] for _ in prompts],
"strings": ["" for _ in prompts],
}
# sampling loop
for _ in range(num_tokens):
with torch.no_grad():
# get logits
Y, H = self.forward(X, H, pad, hooks=hooks)
logits = Y[:, -1] * beta
# sampling only works if logits are floats
logits = logits.float()
# perform sampling
if temperature == 0:
tokens = torch.argmax(logits, dim=-1)
elif top_p is not None:
tokens = nucleus_sample(logits, top_p)
else:
tokens = Categorical(logits=logits).sample()
X, pad = tokens.unsqueeze(-1), None
for batch_idx, token in enumerate(tokens):
out["tokens"][batch_idx].append(token.item())
out["strings"][batch_idx] += self.enc.decode([token.item()])
return out
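    # Usage sketch (untrained weights, so the output text is noise; illustrative only):
    #   xf = Transformer(TransformerConfig())
    #   out = xf.sample(["Hello world"], num_tokens=3, top_p=0.95)
    #   out["tokens"][0]    # three sampled token ids for the first prompt
    #   out["strings"][0]   # the same tokens decoded to text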
@classmethod
def load(
cls,
name_or_path: str,
device: Device | None = None,
dtype: torch.dtype | None = None,
simplify: bool = False,
simplify_kwargs: dict[str, Any] | None = None,
) -> "Transformer":
if name_or_path.startswith("https://"):
path = name_or_path
else:
path = f"https://openaipublic.blob.core.windows.net/neuron-explainer/subject-models/{name_or_path.replace('-', '/')}"
xf = cls.from_checkpoint(
path,
device=device,
dtype=dtype,
)
if simplify:
if simplify_kwargs is None:
simplify_kwargs = {}
xf.simplify(**simplify_kwargs)
return xf
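    # Usage sketch: names without "https://" are mapped to the public blob store, with "-"
    # replaced by "/" in the path (the model name below is a placeholder, not a real checkpoint):
    #   xf = Transformer.load("<model-name>", dtype=torch.float32)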
def save_checkpoint(
self,
path: str,
) -> None:
self.cfg.save(osp.join(path, "config.json"))
pieces_path = osp.join(path, "model_pieces")
for k, v in self.state_dict().items():
with CustomFileHandler(osp.join(pieces_path, f"{k}.pt"), "wb") as f:
torch.save(v, f)
def load_state_from_checkpoint(
self, path: str, device: Device | None = None, dtype: torch.dtype | None = None
):
pieces_path = osp.join(path, "model_pieces")
piece_names = set(self.state_dict().keys())
piece_files = [f"{k}.pt" for k in piece_names]
if dtype is not None:
assert isinstance(dtype, torch.dtype), "Must provide valid dtype."
device = device or self.device
with ThreadPoolExecutor(max_workers=50) as executor:
k_to_future = {
fname[: -len(".pt")]: executor.submit(
_load_piece, osp.join(pieces_path, fname), device, dtype
)
for fname in piece_files
}
d = {k: future.result() for k, future in k_to_future.items()}
self.load_state_dict(d)
@classmethod
def from_checkpoint(
cls, path: str, device: Device | None = None, dtype: torch.dtype | None = None
) -> "Transformer":
if device is None:
device = default_device()
cfg = TransformerConfig.load(osp.join(path, "config.json"))
xf = cls(cfg, device=device, dtype=dtype)
xf.load_state_from_checkpoint(path, device=device, dtype=dtype)
return xf
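# Round-trip sketch (the local path is hypothetical; assumes the file utils accept local paths):
#   xf.save_checkpoint("/tmp/my_ckpt")                 # writes config.json + model_pieces/*.pt
#   xf2 = Transformer.from_checkpoint("/tmp/my_ckpt")  # rebuilds the model and loads the pieces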
def _load_piece(
    file_path: str, device: Device, dtype: torch.dtype | None
) -> torch.Tensor:
disk_cache_path = osp.join(
"/tmp/neuron-explainer-model-pieces-cache", file_path.replace("https://", "")
)
if not file_exists(disk_cache_path):
copy_to_local_cache(file_path, disk_cache_path)
with CustomFileHandler(disk_cache_path, "rb") as f:
t = torch.load(f, map_location=device)
if dtype is not None:
t = t.to(dtype)
return t