in ignite/metrics/multilabel_confusion_matrix.py [0:0]
def _check_input(self, output: Sequence[torch.Tensor]) -> None:
    """Validate a ``(y_pred, y)`` pair before it is used by the metric.

    Both tensors are detached and then checked in this order, stopping at the
    first failure:

    1. each has at least two dimensions;
    2. their sizes along dimension 0 (batch size) agree;
    3. each has ``self.num_classes`` entries along dimension 1;
    4. their full shapes are identical;
    5. each has dtype ``uint8``, ``int8``, ``int16``, ``int32`` or ``int64``;
    6. each contains only the values 0 and 1.

    Within every paired step ``y_pred`` is checked before ``y``, so when both
    are invalid the error names ``y_pred``.

    Args:
        output: sequence whose first item is the prediction tensor and whose
            second item is the target tensor.

    Raises:
        ValueError: describing the first check that fails.
    """
    prediction = output[0].detach()
    target = output[1].detach()
    # Labels are part of the error messages, so they must stay "y_pred" / "y".
    labelled = (("y_pred", prediction), ("y", target))

    for label, tensor in labelled:
        if tensor.ndimension() < 2:
            raise ValueError(
                f"{label} must at least have shape (batch_size, num_classes (currently set to {self.num_classes}), ...)"
            )

    if prediction.shape[0] != target.shape[0]:
        raise ValueError(f"y_pred and y have different batch size: {prediction.shape[0]} vs {target.shape[0]}")

    for label, tensor in labelled:
        if tensor.shape[1] != self.num_classes:
            raise ValueError(
                f"{label} does not have correct number of classes: {tensor.shape[1]} vs {self.num_classes}"
            )

    if target.shape != prediction.shape:
        raise ValueError("y and y_pred shapes must match.")

    integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    for label, tensor in labelled:
        if tensor.dtype not in integer_dtypes:
            raise ValueError(f"{label} must be of any type: {integer_dtypes}")

    # Among integers only 0 and 1 equal their own square, so comparing a
    # tensor with its element-wise square is a binary-values check.
    # NOTE(review): relies on the integer dtype enforced just above.
    for label, tensor in labelled:
        if not torch.equal(tensor, tensor ** 2):
            raise ValueError(f"{label} must be a binary tensor")