def loss_fn()

in lerobot/common/policies/vqbet/modeling_vqbet.py


    def loss_fn(self, pred, target, **kwargs):
        """
        Given ground-truth action values (target) and predictions (pred), this function calculates the overall loss.

        predicted_action: predicted action chunk (offset + decoded centroids)
        sampled_centers: sampled centroids (code of RVQ)
        decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder
        NT: batch size * T
        T: number of action query tokens, which are processed through the same GPT
        cbet_logits: probability of all codes in each layer
        """
        action_seq = target
        predicted_action = pred["predicted_action"]
        sampled_centers = pred["sampled_centers"]
        decoded_action = pred["decoded_action"]
        NT = predicted_action.shape[0] * predicted_action.shape[1]

        cbet_logits = pred["cbet_logits"]

        predicted_action = einops.rearrange(
            predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
        )
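        # predicted_action is now (NT, W, A): one W-step action chunk per token.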

        action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
        # Figure out the loss for the actions.
        # First, we need to find the closest cluster center for each ground truth action.
        with torch.no_grad():
            state_vq, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G
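            # Each column of action_bins holds the ground-truth code index for one
            # RVQ layer (two layers here: primary and secondary codes).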

        # Now we can compute the loss.

        # offset loss is L1 distance between the predicted action and ground truth action
        offset_loss = F.l1_loss(action_seq, predicted_action)

        # calculate primary code prediction loss
        cbet_loss1 = self._focal_loss_fn(
            cbet_logits[:, 0, :],
            action_bins[:, 0],
        )
        # calculate secondary code prediction loss
        cbet_loss2 = self._focal_loss_fn(
            cbet_logits[:, 1, :],
            action_bins[:, 1],
        )
        # combine the per-layer code losses with their configured weights
        cbet_loss = (
            cbet_loss1 * self.config.primary_code_loss_weight
            + cbet_loss2 * self.config.secondary_code_loss_weight
        )

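        # Diagnostics: fraction of action tokens whose sampled codes match the
        # ground-truth codes at each RVQ level.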
        equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT)
        equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT)

        action_mse_error = torch.mean((action_seq - predicted_action) ** 2)
        vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
        offset_action_error = torch.mean(torch.abs(action_seq - predicted_action))
        action_error_max = torch.max(torch.abs(action_seq - predicted_action))

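        # Total objective: code classification loss plus weighted L1 offset loss.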
        loss = cbet_loss + self.config.offset_loss_weight * offset_loss

        loss_dict = {
            "loss": loss,
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(),
            "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(),
            "vq_action_error": vq_action_error.detach().cpu().item(),
            "offset_action_error": offset_action_error.detach().cpu().item(),
            "action_error_max": action_error_max.detach().cpu().item(),
            "action_mse_error": action_mse_error.detach().cpu().item(),
        }
        return loss_dict
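
For reference, the sketch below reproduces how the loss terms combine, as a minimal self-contained example. The shapes, weight values, and names (NT, G, V, W, A, primary_w, secondary_w, offset_w) are hypothetical, and plain cross-entropy stands in for the policy's focal loss.

    import torch
    import torch.nn.functional as F

    # Hypothetical toy dimensions: NT flattened tokens, G=2 RVQ layers,
    # V codes per layer, W action-chunk size, A action dimension.
    NT, G, V, W, A = 8, 2, 16, 5, 7

    cbet_logits = torch.randn(NT, G, V)         # per-layer code logits
    action_bins = torch.randint(0, V, (NT, G))  # ground-truth code indices
    predicted_action = torch.randn(NT, W, A)    # decoded centroids + offsets
    action_seq = torch.randn(NT, W, A)          # ground-truth action chunks

    # Cross-entropy stands in for the _focal_loss_fn used by the policy.
    primary_loss = F.cross_entropy(cbet_logits[:, 0, :], action_bins[:, 0])
    secondary_loss = F.cross_entropy(cbet_logits[:, 1, :], action_bins[:, 1])
    offset_loss = F.l1_loss(action_seq, predicted_action)

    # Illustrative weights standing in for the config fields referenced above.
    primary_w, secondary_w, offset_w = 1.0, 0.5, 100.0
    loss = primary_w * primary_loss + secondary_w * secondary_loss + offset_w * offset_loss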