# download() — from neuron_explainer/scripts/download_from_hf.py


def download(model_name: str, save_dir: str) -> None:
    """Download a GPT-2-family HuggingFace model, convert it, and save it to disk.

    The converted state_dict is written one tensor per file under
    ``<save_dir>/<model_name>/model_pieces/``, and the derived
    ``TransformerConfig`` is serialized to ``<save_dir>/<model_name>/config.json``.

    Args:
        model_name: A model size listed in ``ALL_MODELS``.
        save_dir: Root directory the converted model is written into.

    Raises:
        ValueError: If ``model_name`` is not a recognized model size.
    """
    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, which would let a bad name fail later with a worse error.
    if model_name not in ALL_MODELS:
        raise ValueError(f"Must use valid model size, not {model_name=}")
    print(f"Downloading and converting model {model_name} to {save_dir}...")

    print(f"Getting HuggingFace model {model_name}...")
    model = get_hf_model(model_name)

    hf_config = model.config
    # Fields common to every GPT-2 size; per-size dimensions are read from
    # the HuggingFace config below.
    base_config = dict(
        enc="gpt2",
        ctx_window=1024,
        # attn
        m_attn=1,
        # mlp
        m_mlp=4,
    )
    cfg = TransformerConfig(
        **base_config,  # type: ignore
        d_model=hf_config.n_embd,
        n_layers=hf_config.n_layer,
        n_heads=hf_config.n_head,
    )

    print("Converting state_dict...")
    sd = convert(model.state_dict())

    print(f"Saving model to {save_dir}...")
    # Save each tensor as its own .pt file so pieces can be fetched/loaded
    # individually later.
    pieces_path = osp.join(save_dir, model_name, "model_pieces")
    for k, v in sd.items():
        with CustomFileHandler(osp.join(pieces_path, f"{k}.pt"), "wb") as f:
            torch.save(v, f)

    # Persist the config next to the weights so the model can be rebuilt.
    fname_cfg = osp.join(save_dir, model_name, "config.json")
    with CustomFileHandler(fname_cfg, "w") as f:
        f.write(json.dumps(cfg.__dict__))