in lerobot/common/policies/pi0fast/modeling_pi0fast.py [0:0]
def __init__(self, config: PI0FASTConfig):
    """Build the tokenizers, processors and PaliGemma backbone for PI0-FAST.

    The PaliGemma model is instantiated from a hard-coded configuration
    (no pretrained weights are loaded here), after which the parameters whose
    names contain ``language_model``, ``vision_tower`` or ``multi_modal`` are
    cast to the dtype selected by ``config.precision``.

    Args:
        config: Policy configuration. This constructor reads
            ``fast_skip_tokens``, ``max_input_seq_len``, ``chunk_size``,
            ``action_feature``, ``precision``, ``image_features`` and
            ``padding_side`` from it.
    """
    super().__init__()
    self.config = config

    # TODO: move tokenizers in Policy
    fast_tokenizer_path = "physical-intelligence/fast"
    pi0_paligemma_path = "google/paligemma-3b-pt-224"
    self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
    self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
    # The FAST tokenizer is loaded with `trust_remote_code=True`, i.e. it
    # executes code shipped in the hub repository.
    self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)

    self.fast_skip_tokens = self.config.fast_skip_tokens
    self.max_input_seq_len = self.config.max_input_seq_len
    self.action_horizon = self.config.chunk_size
    self.action_dim = self.config.action_feature.shape[0]

    precision = config.precision
    # Unknown precision names fall back to float32.
    torch_precision = PRECISION.get(precision, torch.float32)

    # Tokenizers expose `pad_token_id` even when no pad token is configured
    # (the value is then None), so a `hasattr` check never reaches the EOS
    # fallback. Compare against None explicitly; 0 is a valid pad id, so a
    # truthiness test would be wrong here.
    pad_token_id = getattr(self.paligemma_tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = self.paligemma_tokenizer.eos_token_id
    self.pad_token_id = pad_token_id

    paligemma_config = CONFIG_MAPPING["paligemma"](
        transformers_version="4.48.1",
        _vocab_size=257152,
        bos_token_id=2,
        eos_token_id=1,
        hidden_size=2048,
        image_token_index=257152,
        model_type="paligemma",
        pad_token_id=0,
        projection_dim=2048,
        text_config={
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 2048,
            "intermediate_size": 16384,
            "model_type": "gemma",
            "num_attention_heads": 8,
            "num_hidden_layers": 18,
            "num_image_tokens": 256,
            "num_key_value_heads": 1,
            "torch_dtype": precision,
            "vocab_size": 257152,
            "_attn_implementation": "eager",
        },
        vision_config={
            "hidden_size": 1152,
            "intermediate_size": 4304,
            "model_type": "siglip_vision_model",
            "num_attention_heads": 16,
            "num_hidden_layers": 27,
            "num_image_tokens": 256,
            "patch_size": 14,
            "projection_dim": 2048,
            "projector_hidden_act": "gelu_pytorch_tanh",
            "torch_dtype": precision,
            "vision_use_head": False,
        },
    )
    self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)
    # Replace the instance's `prepare_inputs_for_generation` with the
    # module-level function of the same name, binding the model as `self`.
    self.pi0_paligemma.prepare_inputs_for_generation = partial(
        prepare_inputs_for_generation, self=self.pi0_paligemma
    )

    # Cast the selected sub-modules' parameters to the configured precision;
    # parameters whose names match none of the selectors keep their dtype.
    params_to_change_dtype = [
        "language_model",
        "vision_tower",
        "multi_modal",
    ]
    for name, param in self.pi0_paligemma.named_parameters():
        if any(selector in name for selector in params_to_change_dtype):
            param.data = param.data.to(dtype=torch_precision)

    self.set_requires_grad()

    self.image_keys = self.config.image_features.keys()
    self.ignore_index = self.pi0_paligemma.config.ignore_index
    self.padding_side = self.config.padding_side