# neuron_explainer/scripts/download_from_hf.py
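"""Download GPT-2 checkpoints from HuggingFace and convert them to the local
TransformerConfig / state-dict format used by neuron_explainer."""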
import json
import os.path as osp

import click
import torch
from transformers import GPT2LMHeadModel

from neuron_explainer.file_utils import CustomFileHandler
from neuron_explainer.models.transformer import TransformerConfig
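
# HuggingFace GPT-2 checkpoints include causal-mask buffers that are not real
# parameters; any key containing one of these substrings is skipped.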
EXCLUDES = [".attn.bias", ".attn.masked_bias"]

ALL_MODELS = [
    "gpt2/small",
    "gpt2/medium",
    "gpt2/large",
    "gpt2/xl",
]


def get_hf_model(model_name: str) -> GPT2LMHeadModel:
    _, model_size = model_name.split("/")
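    # HuggingFace names the smallest checkpoint plain "gpt2"; the larger ones
    # are suffixed, e.g. "gpt2-medium".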
    hf_name = "gpt2" if model_size == "small" else f"gpt2-{model_size}"
    model = GPT2LMHeadModel.from_pretrained(hf_name)
    return model


# ====================================
# Conversion from HuggingFace format
# ====================================
def convert(hf_sd: dict) -> dict:
    """Convert a state_dict from HuggingFace format to our format."""
    n_layers = max(int(k.split(".")[2]) for k in hf_sd.keys() if ".h." in k) + 1
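    # Map HF module names to ours; the mapping is applied below by substring
    # replacement, so per-layer prefixes and shared submodule names can share
    # one table.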
    hf_to_ours = dict()
    hf_to_ours["wte"] = "tok_embed"
    hf_to_ours["wpe"] = "pos_embed"
    hf_to_ours["ln_f"] = "final_ln"
    hf_to_ours["lm_head"] = "unembed"
    for i in range(n_layers):
        hf_to_ours[f"h.{i}"] = f"xf_layers.{i}"
    hf_to_ours["attn.c_attn"] = "attn.linear_qkv"
    hf_to_ours["attn.c_proj"] = "attn.out_proj"
    hf_to_ours["mlp.c_fc"] = "mlp.in_layer"
    hf_to_ours["mlp.c_proj"] = "mlp.out_layer"
    sd = dict()
    for k, v in hf_sd.items():
        if any(x in k for x in EXCLUDES):
            continue
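        # HF stores attention and MLP weights as Conv1D modules, whose weight
        # is the transpose of the equivalent nn.Linear; transpose to match ours.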
if "weight" in k and ("attn" in k or "mlp" in k):
v = v.T
k = k.replace("transformer.", "")
for hf_part, part in hf_to_ours.items():
k = k.replace(hf_part, part)
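        # HF fuses the query/key/value projections into a single c_attn
        # parameter; split it into the three separate projections we use.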
if "attn.linear_qkv." in k:
qproj, kproj, vproj = v.chunk(3, dim=0)
sd[k.replace(".linear_qkv.", ".q_proj.")] = qproj
sd[k.replace(".linear_qkv.", ".k_proj.")] = kproj
sd[k.replace(".linear_qkv.", ".v_proj.")] = vproj
else:
sd[k] = v
return sd


def download(model_name: str, save_dir: str) -> None:
    assert model_name in ALL_MODELS, f"Must use a valid model name, not {model_name=}"
    print(f"Downloading and converting model {model_name} to {save_dir}...")
    print(f"Getting HuggingFace model {model_name}...")
    model = get_hf_model(model_name)
    hf_config = model.config
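    # Settings shared by all GPT-2 sizes. m_attn and m_mlp appear to be width
    # multipliers relative to d_model (an assumption from the names; GPT-2's
    # MLP uses the standard 4x expansion).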
    base_config = dict(
        enc="gpt2",
        ctx_window=1024,
        # attn
        m_attn=1,
        # mlp
        m_mlp=4,
    )
    cfg = TransformerConfig(
        **base_config,  # type: ignore
        d_model=hf_config.n_embd,
        n_layers=hf_config.n_layer,
        n_heads=hf_config.n_head,
    )
print("Converting state_dict...")
sd = convert(model.state_dict())
print(f"Saving model to {save_dir}...")
    pieces_path = osp.join(save_dir, model_name, "model_pieces")
    for k, v in sd.items():
        with CustomFileHandler(osp.join(pieces_path, f"{k}.pt"), "wb") as f:
            torch.save(v, f)
    # Save the config alongside the pieces so the model can be reconstructed.
    fname_cfg = osp.join(save_dir, model_name, "config.json")
    with CustomFileHandler(fname_cfg, "w") as f:
        f.write(json.dumps(cfg.__dict__))


@click.command()
@click.argument("save_dir", type=click.Path(exists=False, file_okay=False))
def download_all(save_dir: str) -> None:
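    """Download and convert every model in ALL_MODELS into save_dir."""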
    for model_name in ALL_MODELS:
        download(model_name, save_dir)


if __name__ == "__main__":
    download_all()