def convert_pi0_checkpoint()

in lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py


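The function body relies on these module-level imports (the lerobot import paths shown are a sketch of the usual locations; the slice_*, get_*_config, and update_keys_with_prefix helpers and the PRECISIONS map are defined elsewhere in this same script and are not imported):

import numpy as np
import torch

from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
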
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
    # Break down orbax ckpts - they are in OCDBT
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
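    # The sliced checkpoint is indexed below through two sub-dicts:
    # "projection_params" (the small projection layers) and "paligemma_params"
    # (the PaliGemma backbone, from which the Gemma expert weights are also split off).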
    # process projection params
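    # Each entry is a Flax-style linear layer stored as a "kernel"/"bias" pair:
    # the state projection, the action input/output projections, and the two
    # layers of the action-time MLP.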
    keys = [
        "state_proj",
        "action_in_proj",
        "action_out_proj",
        "action_time_mlp_in",
        "action_time_mlp_out",
    ]

    projection_params = {}
    for key in keys:
        kernel_params = initial_params["projection_params"][key]["kernel"]
        bias_params = initial_params["projection_params"][key]["bias"]
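        # Some restored checkpoints wrap each leaf as {"value": array}; others
        # expose the raw array directly, so unwrap only when needed.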
        if isinstance(kernel_params, dict):
            weight = kernel_params["value"]
            bias = bias_params["value"]
        else:
            weight = kernel_params
            bias = bias_params
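        # Flax Dense kernels are laid out (in_features, out_features); transpose
        # so the tensor matches torch.nn.Linear's (out_features, in_features).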
        projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
        projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))

    # Process PaliGemma weights
    paligemma_config = get_paligemma_config(precision)
    paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
        initial_params["paligemma_params"], paligemma_config
    )
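    # The second return value, `gemma_raw_dictionary`, carries the leftover Gemma
    # expert weights that are remapped in the next step.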

    # Process Gemma weights (at this stage they are unused)
    gemma_config = get_gemma_config(precision)
    gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)

    # Instantiate model from configs

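    # The checkpoint directory name selects the matching PI0Config: how many empty
    # camera slots to pad and whether to apply the ALOHA-specific state/action
    # adaptations.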
    if "pi0_aloha_sim" in checkpoint_dir:
        pi0_config = PI0Config(
            empty_cameras=2,
            adapt_to_pi_aloha=True,
            use_delta_joint_actions_aloha=False,
        )
    elif "pi0_aloha_towel" in checkpoint_dir:
        pi0_config = PI0Config(
            adapt_to_pi_aloha=True,
            use_delta_joint_actions_aloha=True,
        )
    elif "pi0_base" in checkpoint_dir:
        pi0_config = PI0Config(
            empty_cameras=0,
            adapt_to_pi_aloha=False,
            use_delta_joint_actions_aloha=False,
        )
    else:
        raise ValueError(f"Unsupported checkpoint dir: {checkpoint_dir}")

    # gemma_config=gemma_config, paligemma_config=paligemma_config)
    pi0_model = PI0Policy(pi0_config)

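    # Prefix the flat checkpoint keys so they line up with PI0Policy's module
    # hierarchy before loading the state dict.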
    paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
    gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
    projection_params = update_keys_with_prefix(projection_params, "model.")

    # load state dict
    torch_dtype = PRECISIONS[precision]
    pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
    pi0_model = pi0_model.to(torch_dtype)
    # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    pi0_model.save_pretrained(output_path, safe_serialization=True)
    # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)

    # assert that model loads properly
    del pi0_model
    PI0Policy.from_pretrained(output_path)
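
A minimal invocation sketch for reference; the checkpoint directory, output path, precision, and tokenizer id below are illustrative placeholders rather than values taken from this file:

if __name__ == "__main__":
    # Example only: the directory must contain one of the substrings checked
    # above ("pi0_aloha_sim", "pi0_aloha_towel", "pi0_base"), and the precision
    # string must be a key of the script's PRECISIONS mapping.
    convert_pi0_checkpoint(
        checkpoint_dir="/checkpoints/pi0_base/params",
        precision="float32",
        tokenizer_id="google/paligemma-3b-pt-224",
        output_path="/checkpoints/pi0_base_hf",
    )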