optimum/intel/neural_compressor/trainer.py
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """
    How the loss is computed by Trainer. By default, all models return the loss in the first element.
    """
    # num_items_in_batch is accepted for compatibility with the Trainer API and is unused here
    if self.label_smoother is not None and "labels" in inputs:
        labels = inputs.pop("labels")
    else:
        labels = None
    # Precomputed teacher logits may ship with the batch; pop them so the student
    # model does not receive an unexpected keyword argument
    teacher_outputs = inputs.pop("teacher_logits", None)
    outputs = model(**inputs)
    # Save past state if it exists
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]
    if labels is not None:
        # Causal LMs shift labels internally, so the label smoother must shift as well
        if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)
    else:
        if isinstance(outputs, dict) and "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    if self.distillation_config is not None:
        student_outputs = self._get_logits(outputs)
        if teacher_outputs is not None:
            # Teacher logits of shape (batch, 2, seq_len) hold two stacked logit
            # sets; split them into a tuple of two tensors (sketched below the listing)
            if len(teacher_outputs.shape) == 3 and teacher_outputs.shape[1] == 2:
                teacher_outputs = tuple(teacher_outputs.transpose(1, 0))
        else:
            # No precomputed logits: run the teacher model on the same inputs
            self.distillation_config.teacher_model.eval()
            self.distillation_config.teacher_model.to(model.device)
            teacher_outputs = self.distillation_config.teacher_model(**inputs)
            teacher_outputs = self._get_logits(teacher_outputs)
        if teacher_outputs is not None and self.distillation_callback is not None:
            distillation_loss = self.compute_distillation_loss(student_outputs, teacher_outputs)
            # Weighted average of task loss and distillation loss (see the blend
            # sketch after this listing)
            loss *= self.distillation_callback.criterion.loss_weights[0]
            loss += distillation_loss * self.distillation_callback.criterion.loss_weights[1]
            loss /= sum(self.distillation_callback.criterion.loss_weights)
            if isinstance(outputs, dict):
                outputs["loss"] = loss
            else:
                outputs[0] = loss
    return (loss, outputs) if return_outputs else loss
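
The shape check on the precomputed teacher logits handles a tensor that stacks two
sets of logits along dim 1, consistent with, for example, start/end logits from an
extractive question-answering teacher, though the method itself does not name the
task. A minimal sketch of what tuple(tensor.transpose(1, 0)) does to such a tensor:

import torch

batch_size, seq_len = 4, 128
# Assumed layout: two logit sets (e.g. QA start/end) stacked along dim 1
teacher_logits = torch.randn(batch_size, 2, seq_len)

# transpose(1, 0) -> (2, batch_size, seq_len); tuple() then splits along dim 0
start_logits, end_logits = tuple(teacher_logits.transpose(1, 0))
print(start_logits.shape, end_logits.shape)  # torch.Size([4, 128]) twice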
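
compute_distillation_loss is defined elsewhere in the trainer and its criterion
comes from the INC distillation config, so its exact form is not visible here. A
common choice, and purely an assumption for illustration, is temperature-scaled KL
divergence between student and teacher logits; kd_loss_sketch below is a
hypothetical stand-in, not the actual implementation:

import torch
import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_logits, temperature=2.0):
    # Hypothetical stand-in for compute_distillation_loss: KL divergence between
    # temperature-softened distributions, scaled by temperature**2 so gradient
    # magnitudes stay comparable across temperatures (Hinton et al., 2015)
    s = F.log_softmax(student_logits / temperature, dim=-1)
    t = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(s, t, reduction="batchmean") * temperature**2

kd = kd_loss_sketch(torch.randn(4, 10), torch.randn(4, 10))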
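
The three in-place operations at the end implement a weighted average of the task
loss and the distillation loss. The same arithmetic written as one expression, with
made-up weights and loss values standing in for criterion.loss_weights and the two
computed losses:

import torch

loss_weights = [0.5, 0.5]              # assumed [task weight, distillation weight]
task_loss = torch.tensor(2.0)          # placeholder loss from the labels
distillation_loss = torch.tensor(1.0)  # placeholder loss against the teacher

blended = (loss_weights[0] * task_loss
           + loss_weights[1] * distillation_loss) / sum(loss_weights)
print(blended)  # tensor(1.5000)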