neuron_explainer/scripts/download_from_hf.py
def convert(hf_sd: dict) -> dict:
    """Convert a state_dict from the HuggingFace format to our format."""
    # Infer the number of transformer blocks from keys like "transformer.h.<i>. ...".
    n_layers = max(int(k.split(".")[2]) for k in hf_sd.keys() if ".h." in k) + 1

    # Map HuggingFace module names onto our module names.
    hf_to_ours = dict()
    hf_to_ours["wte"] = "tok_embed"
    hf_to_ours["wpe"] = "pos_embed"
    hf_to_ours["ln_f"] = "final_ln"
    hf_to_ours["lm_head"] = "unembed"
    for i in range(n_layers):
        hf_to_ours[f"h.{i}"] = f"xf_layers.{i}"
    hf_to_ours["attn.c_attn"] = "attn.linear_qkv"
    hf_to_ours["attn.c_proj"] = "attn.out_proj"
    hf_to_ours["mlp.c_fc"] = "mlp.in_layer"
    hf_to_ours["mlp.c_proj"] = "mlp.out_layer"

    sd = dict()
    for k, v in hf_sd.items():
        # Skip parameters we don't need (EXCLUDES is defined elsewhere in this file).
        if any(x in k for x in EXCLUDES):
            continue
        # HuggingFace GPT-2 stores attention/MLP weights in Conv1D layout
        # (in_features, out_features); transpose them to nn.Linear layout.
        if "weight" in k and ("attn" in k or "mlp" in k):
            v = v.T
        # Rename the key from the HuggingFace scheme to ours.
        k = k.replace("transformer.", "")
        for hf_part, part in hf_to_ours.items():
            k = k.replace(hf_part, part)
        if "attn.linear_qkv." in k:
            # Split the fused QKV projection into separate q/k/v parameters.
            qproj, kproj, vproj = v.chunk(3, dim=0)
            sd[k.replace(".linear_qkv.", ".q_proj.")] = qproj
            sd[k.replace(".linear_qkv.", ".k_proj.")] = kproj
            sd[k.replace(".linear_qkv.", ".v_proj.")] = vproj
        else:
            sd[k] = v
    return sd
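

# A minimal usage sketch, not part of the original script: the checkpoint name
# "gpt2" is an illustrative assumption, it requires the HuggingFace
# `transformers` package, and EXCLUDES must be defined above for convert() to run.
if __name__ == "__main__":
    from transformers import GPT2LMHeadModel

    hf_sd = GPT2LMHeadModel.from_pretrained("gpt2").state_dict()
    sd = convert(hf_sd)
    # Converted keys look like "tok_embed.weight" and
    # "xf_layers.0.attn.q_proj.weight", ready for our model's load_state_dict().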