def create_input_tokens()

in lerobot/common/policies/pi0fast/modeling_pi0fast.py


    def create_input_tokens(self, state, lang_text, actions=None):
        bsize = state.shape[0]
        device = state.device
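        # Discretize the state into 256 uniform bins over [-1, 1]; only the first 32 dimensions are kept.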
        bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
        discretized = torch.bucketize(state, bins) - 1
        discretized = discretized[:, :32]

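        # Build per-sample prompt strings: the cleaned task instruction plus the discretized state values.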
        prefix_texts = []
        state_text = []
        for txt, disc in zip(lang_text, discretized, strict=False):
            cleaned = txt.lower().strip().replace("_", " ")
            state_str = " ".join(str(val.item()) for val in disc)
            prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")
            state_text.append(f"State: {state_str};\n")

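        # Tokenize the prefix with the PaliGemma tokenizer and record each sample's prefix length.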
        prefix_out = self.paligemma_tokenizer(
            prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
        )
        prefix_ids = prefix_out["input_ids"].to(device)
        prefix_mask = prefix_out["attention_mask"].to(device)
        prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()

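        # When actions are provided (e.g., during training), normalize and pad them, tokenize them with
        # the FAST tokenizer, and map the resulting ids onto PaliGemma token ids.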
        if actions is not None:
            actions_norm = self.normalize_actions(actions)
            actions_pad = F.pad(
                actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
            )[:, :, : self.config.max_action_dim]
            fast_out = self.fast_tokenizer_wrapper(
                actions_pad.cpu(),
            )
            act_ids = fast_out["input_ids"]
            act_mask = fast_out["attention_mask"].to(device)

            act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device)
            # Action token 0 maps to vocab_size - 1 - fast_skip_tokens after remapping; treat it as padding.
            act_ids = torch.where(
                act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens,
                self.pad_token_id,
                act_ids,
            )

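            # Wrap the action tokens with an "Action: " marker in front and an EOS token at the end.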
            eos_token = torch.tensor(
                [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
            ).expand(bsize, -1)
            eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
            bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
            bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
            bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
            act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
            act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1)
            act_mask = act_mask.to(device)
        else:
            # No actions provided (e.g., at inference): use an empty action suffix.
            act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device)
            act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)
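        # Concatenate prefix and (possibly empty) action suffix, then pad the batch to a common length.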
        final_ids = torch.cat([prefix_ids, act_ids], dim=1)

        final_mask = torch.cat([prefix_mask, act_mask], dim=1)
        batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}

        # Use tokenizer pad function
        padded_output = self.paligemma_tokenizer.pad(
            batch_inputs, padding="longest", max_length=180, return_tensors="pt"
        )
        padded_mask = padded_output["attention_mask"]

        # Boolean mask marking everything after the prefix: action-suffix tokens plus any trailing padding
        att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens

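        # Token type ids distinguish prefix tokens from suffix (action) tokens.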
        token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)

        padded_output["padded_mask"] = padded_output.pop("attention_mask")
        padded_output["attention_mask"] = att_mask
        # The loss is computed only on the action suffix: not on the prefix and not on padding
        padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
        padded_output["token_type_ids"] = token_type_ids
        return padded_output
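
A minimal usage sketch follows, assuming `model` is the module from modeling_pi0fast.py that defines this method, with its tokenizers and config already initialized; the batch size, state dimensionality, action-chunk shape, and task strings below are illustrative assumptions, not values taken from the source.

    import torch

    # Illustrative inputs (hypothetical shapes): batch of 2, 32-dim state, 10-step action chunks.
    # The discretization bins span [-1, 1], so the state is assumed to be pre-normalized.
    state = torch.rand(2, 32) * 2 - 1
    lang_text = ["pick up the red block", "open_the_drawer"]
    actions = torch.randn(2, 10, 7)  # padded internally up to config.max_action_dim

    tokens = model.create_input_tokens(state, lang_text, actions=actions)
    # Returned keys: "input_ids", "padded_mask", "attention_mask", "loss_mask", "token_type_ids".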