in lerobot/common/policies/vqbet/modeling_vqbet.py [0:0]
def loss_fn(self, pred, target, **kwargs):
    """Compute the overall training loss and logging metrics for one batch.

    The total loss is the weighted sum of the primary and secondary code
    prediction losses (both computed by ``self._focal_loss_fn``) plus the
    weighted L1 offset loss between the predicted and ground-truth action chunks.

    Notation:
        N: batch size
        T: number of action query tokens, which are processed through the same GPT
        NT: N * T
        W: action chunk size (``self.config.action_chunk_size``)
        A: per-step action dimension
        G: number of code layers (second axis of the codes returned by
            ``self.vqvae_model.get_code``, per the inline comment below)

    Args:
        pred: dict with the following entries.
            "predicted_action": predicted action chunk (offset + decoded
                centroids), laid out as ``N T (W A)``.
            "sampled_centers": sampled centroids (codes of the RVQ); columns 0
                and 1 are compared against the ground-truth codes.
            "decoded_action": action produced by passing ``sampled_centers``
                through the RVQ decoder; it must broadcast against the
                ``(NT, W, A)`` ground-truth actions.
            "cbet_logits": code prediction scores for each layer, indexed here
                as ``[:, layer, :]``.
        target: ground-truth actions laid out as ``N T W A``.
        **kwargs: accepted for call-site compatibility; not used.

    Returns:
        dict: ``"loss"`` holds the total loss tensor (still attached to the
        autograd graph so the caller can backpropagate). Every other entry is a
        plain Python number intended for logging.
    """
    predicted_action = pred["predicted_action"]
    sampled_centers = pred["sampled_centers"]
    decoded_action = pred["decoded_action"]
    cbet_logits = pred["cbet_logits"]

    # Must be read before the first two axes are flattened together below.
    num_queries = predicted_action.shape[0] * predicted_action.shape[1]

    predicted_action = einops.rearrange(
        predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
    )
    action_seq = einops.rearrange(target, "N T W A -> (N T) W A")

    # Find the closest cluster center (code) for each ground-truth action.
    # These codes are classification targets, so no gradient is needed.
    with torch.no_grad():
        _, action_bins = self.vqvae_model.get_code(action_seq)  # action_bins: NT, G

    # Offset loss: L1 distance between the predicted and ground-truth actions.
    offset_loss = F.l1_loss(action_seq, predicted_action)

    # Primary (layer 0) and secondary (layer 1) code prediction losses.
    cbet_loss1 = self._focal_loss_fn(cbet_logits[:, 0, :], action_bins[:, 0])
    cbet_loss2 = self._focal_loss_fn(cbet_logits[:, 1, :], action_bins[:, 1])
    cbet_loss = (
        cbet_loss1 * self.config.primary_code_loss_weight
        + cbet_loss2 * self.config.secondary_code_loss_weight
    )

    loss = cbet_loss + self.config.offset_loss_weight * offset_loss

    # Everything below is logging-only, so skip autograd bookkeeping and compute
    # the action residual once instead of once per metric.
    with torch.no_grad():
        equal_primary_code_rate = (
            torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / num_queries
        )
        equal_secondary_code_rate = (
            torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / num_queries
        )
        action_diff = action_seq - predicted_action
        abs_action_diff = torch.abs(action_diff)
        action_mse_error = torch.mean(action_diff**2)
        vq_action_error = torch.mean(torch.abs(action_seq - decoded_action))
        offset_action_error = torch.mean(abs_action_diff)
        action_error_max = torch.max(abs_action_diff)

    # ``Tensor.item()`` already yields a host-side Python number and is not
    # differentiable, so no explicit ``.detach().cpu()`` is required.
    return {
        "loss": loss,
        "classification_loss": cbet_loss.item(),
        "offset_loss": offset_loss.item(),
        "equal_primary_code_rate": equal_primary_code_rate.item(),
        "equal_secondary_code_rate": equal_secondary_code_rate.item(),
        "vq_action_error": vq_action_error.item(),
        "offset_action_error": offset_action_error.item(),
        "action_error_max": action_error_max.item(),
        "action_mse_error": action_mse_error.item(),
    }