def forward()

in lerobot/common/policies/vqbet/modeling_vqbet.py


    def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
        # Input validation.
        assert set(batch).issuperset({"observation.state", "observation.images"})
        batch_size, n_obs_steps = batch["observation.state"].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        # Extract image feature (first combine batch and sequence dims).
        img_features = self.rgb_encoder(
            einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
        )
        # Separate batch and sequence dims.
        img_features = einops.rearrange(
            img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
        )
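        # img_features: (batch, n_obs_steps, num_images, encoder feature dim)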

        # Arrange prior and current observation step tokens as shown in the class docstring.
        # First project features to token dimension.
        rgb_tokens = self.rgb_feature_projector(
            img_features
        )  # (batch, obs_step, number of different cameras, projection dims)
        input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))]
        input_tokens.append(
            self.state_projector(batch["observation.state"])
        )  # (batch, obs_step, projection dims)
        input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps))
        # Interleave tokens by stacking and rearranging.
        input_tokens = torch.stack(input_tokens, dim=2)
        input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
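        # input_tokens: (batch, n_obs_steps * (num_images + 2), projection dim); each observation step
        # contributes [camera tokens..., state token, action query token] in that order.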

        len_additional_action_token = self.config.n_action_pred_token - 1
        future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1)

        # add additional action query tokens for predicting future action chunks
        input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)

        # get action features (pass through GPT)
        features = self.policy(input_tokens)
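        # features: (batch, n_obs_steps * (len(input_features) + 1) + len_additional_action_token, GPT output dim)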
        # len(self.config.input_features) is the number of different observation modalities.
        # These lines compute the indices of the action query tokens (one per observation step) in the interleaved sequence.
        historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
            self.config.input_features
        )
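        # e.g. with two cameras and the state (len(input_features) == 3) and n_obs_steps == 5, each step
        # spans 4 tokens and the action query comes last, so historical_act_pred_index == [3, 7, 11, 15, 19].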

        # Only extract the output tokens at the positions of the action query tokens:
        # Behavior Transformer (BeT) and VQ-BeT are both sequence-to-sequence prediction models,
        # mapping sequential observations to sequential actions (please refer to section 2.2 in the BeT paper https://huggingface.co/papers/2206.11251).
        # Thus, it predicts a historical action sequence in addition to current and future actions (predicting future actions is optional).
        if len_additional_action_token > 0:
            features = torch.cat(
                [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
            )
        else:
            features = features[:, historical_act_pred_index]
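        # features: (batch, n_obs_steps + len_additional_action_token, GPT output dim), i.e. one token per
        # observation step plus one per extra future action query.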
        # pass through action head
        action_head_output = self.action_head(features)
        # During rollout, VQ-BeT does not compute the loss.
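        # It only returns the action chunk predicted from the most recent observation step, reshaped to
        # (batch, action_chunk_size, action dim).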
        if rollout:
            return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape(
                batch_size, self.config.action_chunk_size, -1
            )
        # Otherwise, compute the overall loss (bin prediction loss and offset loss).
        else:
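            # For each action prediction token, gather its target chunk of ground-truth actions.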
            output = batch[ACTION][:, self.select_target_actions_indices]
            loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
            return action_head_output, loss
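
A minimal, runnable sketch of the token interleaving and the action-query index arithmetic used above, with toy shapes (2 cameras, 5 observation steps, 2 extra future action query tokens). The tensor names are hypothetical stand-ins for illustration, not lerobot APIs.

import einops
import numpy as np
import torch

# Toy dimensions (hypothetical, for illustration only).
batch_size, n_obs_steps, num_images, dim = 2, 5, 2, 8
n_input_features = num_images + 1  # camera features + proprioceptive state
len_additional_action_token = 2

cam_tokens = [torch.randn(batch_size, n_obs_steps, dim) for _ in range(num_images)]
state_tokens = torch.randn(batch_size, n_obs_steps, dim)
action_query = torch.zeros(batch_size, n_obs_steps, dim)  # stands in for the learned action token

# Interleave per observation step: [cam_0, cam_1, state, action query].
tokens = torch.stack([*cam_tokens, state_tokens, action_query], dim=2)
tokens = einops.rearrange(tokens, "b n t d -> b (n t) d")
assert tokens.shape == (batch_size, n_obs_steps * (n_input_features + 1), dim)

# Append the extra action query tokens used to predict future chunks.
future_queries = action_query[:, :1].repeat(1, len_additional_action_token, 1)
tokens = torch.cat([tokens, future_queries], dim=1)

# Positions of the action query token within each observation step.
historical_act_pred_index = np.arange(n_obs_steps) * (n_input_features + 1) + n_input_features
print(historical_act_pred_index)  # [ 3  7 11 15 19]
# These positions, plus the trailing future queries, are the GPT outputs fed to the action head.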