def forward()

in lerobot/common/policies/pi0fast/modeling_pi0fast.py [0:0]
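This training forward pass prepares the image features, tokenizes the state, language instruction, and action targets into a single token sequence, embeds everything, runs the PaliGemma backbone with a block-causal attention mask, and returns a masked next-token cross-entropy loss.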


    def forward(self, batch: dict[str, Tensor]):
        device = batch[OBS_STATE].device
        # TODO: keep like this or move to the policy .forward
        images, img_masks = self.prepare_images(batch)

        padded_outs = self.create_input_tokens(
            state=batch[OBS_STATE],
            lang_text=batch["task"],
            actions=batch[ACTION],
        )

        embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
            images,
            img_masks,
            padded_outs["input_ids"],
            padded_outs["padded_mask"],
            padded_outs["attention_mask"],
            padded_outs["loss_mask"],
            padded_outs["token_type_ids"],
            padding_side=self.padding_side,
        )
        # Position ids count only non-padded tokens (0-based), so padding never shifts positions.
        position_ids = torch.cumsum(pad_masks, dim=1) - 1
        token_type_ids = token_type_ids.to(dtype=torch.int64)
        # No KV cache is used during training, so cache positions simply enumerate the sequence.
        past_seen_tokens = 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device)
        # Expand the padding mask into a block-causal attention mask driven by the token types.
        pad_masks = block_causal_update_causal_mask(
            attention_mask=pad_masks,
            past_key_values=None,
            cache_position=cache_position,
            input_tensor=embs,
            token_type_ids=token_type_ids,
            dtype=self.pi0_paligemma.dtype,
            attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
        )
        # Run the PaliGemma backbone on the precomputed embeddings (input_ids=None), without KV caching.
        outputs = self.pi0_paligemma.forward(
            input_ids=None,
            token_type_ids=None,
            attention_mask=pad_masks,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=embs,
            use_cache=False,
            labels=None,
        )

        logits = outputs.logits

        # Per-token cross-entropy; masking and averaging are applied manually below
        loss_fct = nn.CrossEntropyLoss(reduction="none")

        # Shift so that the logits at position t predict the token at position t + 1
        logits = logits[:, :-1, :]
        targets = targets[:, 1:].to(device)
        loss_mask = loss_mask[:, 1:].to(device)  # Shift the loss mask to stay aligned with the targets

        # Compute per-token loss
        token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))

        # Apply loss mask
        token_loss = token_loss * loss_mask.reshape(-1)

        # Masked mean over supervised tokens; the clamp avoids division by zero when nothing is supervised
        loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1)

        # "loss" keeps the autograd graph for backprop; "ce_loss" is a detached float for logging
        loss_dict = {"ce_loss": loss.item(), "loss": loss}
        return loss_dict
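
For orientation, here is a minimal sketch of how this forward pass could be driven from a training step. The batch keys, camera key, and tensor shapes below are illustrative assumptions (the real ones depend on the policy and dataset configuration), `model` stands for the module that owns this method, and OBS_STATE / ACTION are assumed to resolve to the "observation.state" and "action" keys:

    import torch

    # Illustrative batch; keys and shapes are assumptions, not taken from the repository.
    batch = {
        "observation.state": torch.randn(2, 7),                 # 2 samples, 7-D proprioceptive state
        "observation.images.top": torch.rand(2, 3, 224, 224),   # one hypothetical camera key
        "task": ["pick up the cube", "open the drawer"],         # language instructions
        "action": torch.randn(2, 10, 7),                          # 10-step action chunks, 7 dims each
    }

    loss_dict = model.forward(batch)                # `model`: instance of the class owning this method
    loss_dict["loss"].backward()                    # differentiable scalar, used for the optimizer step
    print(f"ce_loss: {loss_dict['ce_loss']:.4f}")   # already detached via .item(), convenient for logging

Note that only the value under "loss" retains the computation graph; "ce_loss" is detached and intended purely for logging.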