def calculate_loss()

in src/jobs/distill_t5.py


    # Assumes `import torch.nn.functional as F` at module level.
    def calculate_loss(self, student_outputs, teacher_outputs, labels, alpha=0.9, temperature=2.0):
        """
        Computes the blended distillation + CE loss for seq2seq models.

        Args:
            student_outputs: ModelOutput from the student (includes logits)
            teacher_outputs: ModelOutput from the teacher (includes logits)
            labels: Ground-truth target token IDs (shape [batch, tgt_len])
            alpha: Weight on the distillation (KL) term; the CE term is weighted by (1 - alpha)
            temperature: Softmax temperature used to soften both distributions

        Returns:
            total_loss: Weighted sum of distillation and CE losses
        """

        student_logits = student_outputs.logits  # [batch, tgt_len, vocab_size]
        teacher_logits = teacher_outputs.logits.detach()  # teacher supplies targets only; block gradients defensively

        ce_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            ignore_index=self.tokenizer.pad_token_id,
            reduction='mean'
        )

        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)

        # Per-token KL, summed over the vocab dim. reduction="batchmean" here
        # would divide by the batch size only and include padding positions,
        # so we mask pad tokens and average over the real tokens instead.
        kl_per_token = F.kl_div(
            student_log_probs,
            teacher_probs,
            reduction="none"
        ).sum(dim=-1)  # [batch, tgt_len]
        pad_mask = labels.ne(self.tokenizer.pad_token_id)
        kl_loss = (kl_per_token * pad_mask).sum() / pad_mask.sum().clamp(min=1)
        kl_loss = kl_loss * (temperature ** 2)  # compensate for the 1/T^2 gradient scaling of the softened softmax

        total_loss = alpha * kl_loss + (1.0 - alpha) * ce_loss

        return total_loss
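
Below is a minimal usage sketch of how calculate_loss might be invoked from a training step. The surrounding wiring (self.student, self.teacher, the batch keys, and training_step itself) is assumed for illustration and is not part of distill_t5.py; it also assumes `import torch` at module level. The teacher runs under torch.no_grad() since it only supplies soft targets.

    def training_step(self, batch):
        labels = batch["labels"]  # [batch, tgt_len], padded with pad_token_id

        # Student forward pass; gradients flow through these logits.
        student_outputs = self.student(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=labels,
        )

        # Teacher forward pass without gradients; only its logits are used.
        with torch.no_grad():
            teacher_outputs = self.teacher(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=labels,
            )

        loss = self.calculate_loss(student_outputs, teacher_outputs, labels)
        loss.backward()
        return loss.detach()

With the default alpha=0.9, the gradient signal is dominated by the teacher's softened distribution, while the remaining 0.1 weight on the hard-label CE term keeps the student anchored to the ground truth.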