in src/jobs/distill_t5.py [0:0]
def calculate_loss(self, student_outputs, teacher_outputs, labels, alpha=0.9, temperature=2.0):
"""
Computes the blended distillation + CE loss for seq2seq models.
Args:
student_outputs: ModelOutput from the student (includes logits)
teacher_outputs: ModelOutput from the teacher (includes logits)
labels: Ground-truth target token IDs (shape [batch, tgt_len])
alpha: Weight for distillation loss (vs. CE loss)
temperature: Softmax temperature
Returns:
total_loss: Weighted sum of distillation and CE losses
"""
    student_logits = student_outputs.logits  # [batch, tgt_len, vocab_size]
    teacher_logits = teacher_outputs.logits

    # Hard-label cross-entropy against the ground truth, ignoring padding positions.
    ce_loss = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=self.tokenizer.pad_token_id,
        reduction="mean",
    )
    # Soft-label KL between temperature-scaled distributions. Flatten to
    # (tokens, vocab) and drop padding positions so the KL averages over the
    # same tokens as the CE term ("batchmean" then means per-token mean; with
    # the unflattened 3-D input it would divide by the batch size only).
    vocab_size = student_logits.size(-1)
    pad_mask = labels.view(-1) != self.tokenizer.pad_token_id
    student_log_probs = F.log_softmax(
        student_logits.view(-1, vocab_size)[pad_mask] / temperature, dim=-1
    )
    teacher_probs = F.softmax(
        teacher_logits.view(-1, vocab_size)[pad_mask] / temperature, dim=-1
    )
    kl_loss = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction="batchmean",
    ) * (temperature ** 2)  # scale back by T^2 so soft-target gradients keep their magnitude
    total_loss = alpha * kl_loss + (1.0 - alpha) * ce_loss
    return total_loss
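
# A minimal sketch of how calculate_loss might be wired into a training step.
# Everything here other than calculate_loss itself (train_step, self.student,
# self.teacher, the batch keys, and a module-level `import torch`) is an
# illustrative assumption, not code taken from this file.
def train_step(self, batch):
    # Teacher runs in inference mode; only the student receives gradients.
    with torch.no_grad():
        teacher_outputs = self.teacher(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
    student_outputs = self.student(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
    )
    return self.calculate_loss(
        student_outputs, teacher_outputs, batch["labels"],
        alpha=0.9, temperature=2.0,
    )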