in quant/common/metrics.py [0:0]
def update(self, output: Tensor, target: Tensor,
           teacher_output: Optional[Tensor] = None, **kwargs: Any) -> None:
    """
    Fold the loss of the current batch into the metric state.

    In accumulating mode the batch loss is added to ``self.total`` and the
    batch size (``output.shape[0]``) is added to ``self.n_examples``.
    Otherwise ``self.total`` is overwritten with the loss of this batch and
    ``self.n_examples`` is left untouched.

    Args:
        output: the output of the model
        target: the target we want the model to predict
        teacher_output: teacher output for knowledge distillation; when
            given, the criterion is called as
            ``criterion(output, teacher_output, target)`` instead of
            ``criterion(output, target, reduction=...)``
        **kwargs: accepted but not used by this method
    """
    if teacher_output is not None:
        # Distillation path: the criterion is evaluated once, up front.
        kd_loss = self.criterion(output, teacher_output, target).item()  # type: ignore
        if not self.accumulate:
            self.total = kd_loss
            return
        self.n_examples += output.shape[0]
        # kd criterion uses batchmean, so scale it back up by the batch size
        self.total += kd_loss * output.shape[0]
        return

    if not self.accumulate:
        # Stateless mode: report only the mean loss of the current batch.
        self.total = self.criterion(output, target, reduction='mean').item()
        return

    # Accumulating mode: keep a running sum of per-example losses.
    self.n_examples += output.shape[0]
    self.total += self.criterion(output, target, reduction='sum').item()