in ignite/metrics/accuracy.py [0:0]
def _check_type(self, output: Sequence[torch.Tensor]) -> None:
    """Infer the input type of ``output`` and check it agrees with earlier updates.

    The update type is derived from the ranks of ``y`` and ``y_pred``:

    - ``y_pred`` has exactly one more dimension than ``y``: ``y_pred.shape[1]`` is
      taken as the number of classes. A single class means ``"binary"`` (the pair is
      then passed to ``_check_binary_multilabel_cases``); more than one means
      ``"multiclass"``.
    - ``y_pred`` and ``y`` have the same rank: the pair is passed to
      ``_check_binary_multilabel_cases``. The type is ``"multilabel"`` with
      ``y_pred.shape[1]`` classes when ``self._is_multilabel`` is truthy, otherwise
      ``"binary"`` with one class.

    On the first call (``self._type is None``) the inferred type and class count are
    stored in ``self._type`` / ``self._num_classes``. Later calls must infer the same
    values.

    Args:
        output: pair ``(y_pred, y)`` of tensors.

    Raises:
        RuntimeError: if the ranks of ``y`` and ``y_pred`` fit neither pattern above,
            or if the inferred type differs from the one recorded earlier.
        ValueError: if the inferred number of classes differs from the one recorded
            earlier.
    """
    y_pred, y = output

    if y.ndimension() + 1 == y_pred.ndimension():
        # In this layout dim 1 of y_pred is the class dimension.
        num_classes = y_pred.shape[1]
        if num_classes == 1:
            update_type = "binary"
            self._check_binary_multilabel_cases((y_pred, y))
        else:
            update_type = "multiclass"
    elif y.ndimension() == y_pred.ndimension():
        # NOTE(review): _check_binary_multilabel_cases is defined elsewhere in the class;
        # presumably it validates values/shapes for binary or multilabel input — confirm.
        self._check_binary_multilabel_cases((y_pred, y))
        if self._is_multilabel:
            update_type = "multilabel"
            num_classes = y_pred.shape[1]
        else:
            update_type = "binary"
            num_classes = 1
    else:
        raise RuntimeError(
            f"Invalid shapes of y (shape={y.shape}) and y_pred (shape={y_pred.shape}), check documentation"
            " for expected shapes of y and y_pred."
        )

    if self._type is None:
        # First update: remember the kind of data this metric accumulates so that
        # later batches can be checked against it.
        self._type = update_type
        self._num_classes = num_classes
        return

    if self._type != update_type:
        raise RuntimeError(f"Input data type has changed from {self._type} to {update_type}.")
    if self._num_classes != num_classes:
        raise ValueError(f"Input data number of classes has changed from {self._num_classes} to {num_classes}")