in lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py [0:0]
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str) -> None:
    """Convert an orbax pi0 checkpoint into a saved ``PI0Policy``.

    The checkpoint is sliced into projection, PaliGemma and Gemma parameter
    groups, the groups are re-keyed to the ``PI0Policy`` state-dict layout,
    loaded into a freshly built policy, cast to the requested precision and
    written to ``output_path``. The saved policy is then reloaded once as a
    sanity check.

    Args:
        checkpoint_dir: Path of the orbax checkpoint. It must contain one of
            ``"pi0_aloha_sim"``, ``"pi0_aloha_towel"`` or ``"pi0_base"``,
            which selects the ``PI0Config`` used to build the policy.
        precision: Key into ``PRECISIONS``; also forwarded to
            ``get_paligemma_config`` and ``get_gemma_config``.
        tokenizer_id: Currently unused, because the tokenizer export is
            disabled. Kept so existing callers keep working.
        output_path: Directory passed to ``save_pretrained`` and then to
            ``from_pretrained`` for the reload check.

    Raises:
        ValueError: If ``checkpoint_dir`` matches none of the known
            checkpoint names.
    """
    # Pick the policy config before touching the checkpoint: the choice only
    # depends on the path string, so an unsupported checkpoint fails right
    # away instead of after the whole checkpoint has been loaded and sliced.
    if "pi0_aloha_sim" in checkpoint_dir:
        pi0_config = PI0Config(
            empty_cameras=2,
            adapt_to_pi_aloha=True,
            use_delta_joint_actions_aloha=False,
        )
    elif "pi0_aloha_towel" in checkpoint_dir:
        pi0_config = PI0Config(
            adapt_to_pi_aloha=True,
            use_delta_joint_actions_aloha=True,
        )
    elif "pi0_base" in checkpoint_dir:
        pi0_config = PI0Config(
            empty_cameras=0,
            adapt_to_pi_aloha=False,
            use_delta_joint_actions_aloha=False,
        )
    else:
        raise ValueError(
            f"Unrecognized pi0 checkpoint {checkpoint_dir!r}: the path must contain one of "
            "'pi0_aloha_sim', 'pi0_aloha_towel' or 'pi0_base'."
        )

    # Break down orbax ckpts - they are in OCDBT
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)

    # Process projection params
    keys = [
        "state_proj",
        "action_in_proj",
        "action_out_proj",
        "action_time_mlp_in",
        "action_time_mlp_out",
    ]
    projection_params = {}
    for key in keys:
        kernel_params = initial_params["projection_params"][key]["kernel"]
        bias_params = initial_params["projection_params"][key]["bias"]
        # Depending on the checkpoint, the arrays are either stored directly
        # or wrapped in a dict under "value".
        if isinstance(kernel_params, dict):
            weight = kernel_params["value"]
            bias = bias_params["value"]
        else:
            weight = kernel_params
            bias = bias_params
        # The kernel is transposed to become the torch "weight"
        # (presumably (in, out) -> (out, in) for nn.Linear; confirm against the model).
        projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
        projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))

    # Process PaliGemma weights
    paligemma_config = get_paligemma_config(precision)
    paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
        initial_params["paligemma_params"], paligemma_config
    )

    # Process Gemma weights (at this stage they are unused)
    gemma_config = get_gemma_config(precision)
    gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)

    # Instantiate model from the config selected above
    pi0_model = PI0Policy(pi0_config)

    # Re-key every parameter group to match the PI0Policy state-dict layout
    paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
    gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
    projection_params = update_keys_with_prefix(projection_params, "model.")

    # Load state dict, then cast the whole policy to the requested precision
    torch_dtype = PRECISIONS[precision]
    pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
    pi0_model = pi0_model.to(torch_dtype)

    # NOTE: the tokenizer is not exported alongside the policy, which is why
    # `tokenizer_id` is not used here.
    pi0_model.save_pretrained(output_path, safe_serialization=True)

    # Assert that the model loads properly; drop the in-memory copy first
    del pi0_model
    PI0Policy.from_pretrained(output_path)