in lerobot/common/policies/pi0fast/modeling_pi0fast.py
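# Module-level imports this method relies on:
import torch
import torch.nn.functional as F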
def create_input_tokens(self, state, lang_text, actions=None):
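    # Discretize each state dimension into 256 bins spanning [-1, 1] and keep at most the first 32 dimensions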
    bsize = state.shape[0]
    device = state.device
    bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
    discretized = torch.bucketize(state, bins) - 1
    discretized = discretized[:, :32]
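    # Build per-sample text prefixes combining the task instruction with the discretized state string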
    prefix_texts = []
    state_text = []
    for txt, disc in zip(lang_text, discretized, strict=False):
        cleaned = txt.lower().strip().replace("_", " ")
        state_str = " ".join(str(val.item()) for val in disc)
        prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
        state_text.append(f"State: {state_str};\n")
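    # Tokenize the prefixes with the PaliGemma tokenizer and record each sample's prefix length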
    prefix_out = self.paligemma_tokenizer(
        prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
    )
    prefix_ids = prefix_out["input_ids"].to(device)
    prefix_mask = prefix_out["attention_mask"].to(device)
    prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()
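    # Training: tokenize the normalized target actions with the FAST tokenizer and map the ids into the PaliGemma vocabulary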
    if actions is not None:
        actions_norm = self.normalize_actions(actions)
        actions_pad = F.pad(
            actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
        )[:, :, : self.config.max_action_dim]
        fast_out = self.fast_tokenizer_wrapper(
            actions_pad.cpu(),
        )
        act_ids = fast_out["input_ids"]
        act_mask = fast_out["attention_mask"].to(device)
        act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
        # Action token 0 maps to vocab_size - 1 - fast_skip_tokens after remapping; replace it with the PaliGemma pad token id
        act_ids = torch.where(
            act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
            self.pad_token_id,
            act_ids,
        )
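        # Wrap the action tokens with an "Action: " prefix and an EOS token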
        eos_token = torch.tensor(
            [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
        ).expand(bsize, -1)
        eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
        bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
        bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
        bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
        act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
        act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
        act_mask = act_mask.to(device)
    else:
        # Inference: no action targets, so use zero-width suffix tensors (the concatenation below is a no-op)
        act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device)
        act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
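    # Concatenate the prefix and the (possibly empty) action suffix into one token sequence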
    final_ids = torch.cat([prefix_ids, act_ids], dim=1)
    final_mask = torch.cat([prefix_mask, act_mask], dim=1)
    batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}
    # Use tokenizer pad function
    padded_output = self.paligemma_tokenizer.pad(
        batch_inputs, padding="longest", max_length=180, return_tensors="pt"
    )
    padded_mask = padded_output["attention_mask"]
    # boolean mask marking tokens that come after the prefix (the action suffix)
    att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
    token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)
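    # Keep the raw padding mask under "padded_mask" and expose the suffix mask as "attention_mask"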
padded_output["padded_mask"] = padded_output.pop("attention_mask")
padded_output["attention_mask"] = att_mask
# loss is computed not on prefix, and not on padding
padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
padded_output["token_type_ids"] = token_type_ids
return padded_output