def convert()

in neuron_explainer/scripts/download_from_hf.py


def convert(hf_sd: dict) -> dict:
    """convert state_dict from HuggingFace format to our format"""
    n_layers = max([int(k.split(".")[2]) for k in hf_sd.keys() if ".h." in k]) + 1

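    # Map HuggingFace submodule names onto this repo's module names.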
    hf_to_ours = dict()
    hf_to_ours["wte"] = "tok_embed"
    hf_to_ours["wpe"] = "pos_embed"
    hf_to_ours["ln_f"] = "final_ln"
    hf_to_ours["lm_head"] = "unembed"
    for i in range(n_layers):
        hf_to_ours[f"h.{i}"] = f"xf_layers.{i}"
    hf_to_ours["attn.c_attn"] = "attn.linear_qkv"
    hf_to_ours["attn.c_proj"] = "attn.out_proj"
    hf_to_ours["mlp.c_fc"] = "mlp.in_layer"
    hf_to_ours["mlp.c_proj"] = "mlp.out_layer"

    sd = dict()
    for k, v in hf_sd.items():
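        # EXCLUDES is a constant defined elsewhere in this script; it lists key
        # substrings whose entries should be skipped entirely.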
        if any(x in k for x in EXCLUDES):
            continue
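        # HuggingFace GPT-2 stores attention/MLP weights in Conv1D modules with
        # shape (in_features, out_features); transpose to the usual nn.Linear
        # layout of (out_features, in_features).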
        if "weight" in k and ("attn" in k or "mlp" in k):
            v = v.T
        k = k.replace("transformer.", "")
        for hf_part, part in hf_to_ours.items():
            k = k.replace(hf_part, part)
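        # The fused QKV projection is split into three equal chunks along the
        # output dimension: one each for the query, key, and value projections.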
        if "attn.linear_qkv." in k:
            qproj, kproj, vproj = v.chunk(3, dim=0)
            sd[k.replace(".linear_qkv.", ".q_proj.")] = qproj
            sd[k.replace(".linear_qkv.", ".k_proj.")] = kproj
            sd[k.replace(".linear_qkv.", ".v_proj.")] = vproj
        else:
            sd[k] = v

    return sd
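
Below is a minimal usage sketch, not part of the script itself. It assumes the transformers package is installed, that EXCLUDES is defined as in the rest of download_from_hf.py, and that the import path can be inferred from the file location shown above.

from transformers import GPT2LMHeadModel

# Import path inferred from the file location above (assumption).
from neuron_explainer.scripts.download_from_hf import convert

# Load the HuggingFace GPT-2 checkpoint and take its raw state dict.
hf_model = GPT2LMHeadModel.from_pretrained("gpt2")
hf_sd = hf_model.state_dict()

# Rename keys, transpose Conv1D weights, and split the fused QKV projection.
sd = convert(hf_sd)

# Layer 0's fused QKV weight now appears as separate q/k/v projection weights.
print([k for k in sd if k.startswith("xf_layers.0.attn.") and k.endswith("weight")])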