in ignite/metrics/precision.py [0:0]
def update(self, output: Sequence[torch.Tensor]) -> None:
    """Accumulate true-positive and predicted-positive counts from one batch.

    Shape and type validation is delegated to ``self._check_shape`` and
    ``self._check_type``. Depending on ``self._type``, the inputs are reshaped
    so that summing over dim 0 yields the counts that get accumulated:

    * ``"binary"``: both tensors are flattened, giving scalar counts.
    * ``"multiclass"``: ``y`` and ``argmax(y_pred, dim=1)`` are one-hot
      encoded, giving one count per class.
    * ``"multilabel"``: ``(N, C, ...)`` is viewed as ``(C, N x ...)``, giving
      one count per sample/position.

    Args:
        output: sequence whose first element is the prediction ``y_pred`` and
            whose second element is the target ``y``. Both are detached
            before use.

    Raises:
        ValueError: for the ``"multiclass"`` type, when ``y`` holds a class
            index that is not smaller than ``y_pred.size(1)``.
    """
    self._check_shape(output)
    self._check_type(output)
    y_pred, y = output[0].detach(), output[1].detach()

    if self._type == "binary":
        y_pred = y_pred.view(-1)
        y = y.view(-1)
    elif self._type == "multiclass":
        num_classes = y_pred.size(1)
        # Reduce once; the value is reused by the check and the error message.
        max_label = y.max().item()
        if max_label + 1 > num_classes:
            raise ValueError(
                f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
                f" and element in y has invalid class = {max_label + 1}."
            )
        y = to_onehot(y.view(-1), num_classes=num_classes)
        indices = torch.argmax(y_pred, dim=1).view(-1)
        y_pred = to_onehot(indices, num_classes=num_classes)
    elif self._type == "multilabel":
        # if y, y_pred shape is (N, C, ...) -> (C, N x ...)
        num_classes = y_pred.size(1)
        y_pred = torch.transpose(y_pred, 1, 0).reshape(num_classes, -1)
        y = torch.transpose(y, 1, 0).reshape(num_classes, -1)

    # Convert from int cuda/cpu to double on self._device
    y_pred = y_pred.to(dtype=torch.float64, device=self._device)
    y = y.to(dtype=torch.float64, device=self._device)

    all_positives = y_pred.sum(dim=0)
    # With 0/1 inputs (presumably enforced by _check_type; one-hot in the
    # multiclass case) the product is 1 exactly where prediction and target
    # agree on a positive. No special case is needed for "no hits": summing
    # zeros already yields float64 zeros shaped like ``all_positives`` (when
    # y and y_pred share a shape). Avoiding a Python-level ``if`` on a tensor
    # also avoids a device->host sync on every update.
    true_positives = (y * y_pred).sum(dim=0)

    if self._type == "multilabel":
        if not self._average:
            # Keep one entry per sample/position, concatenated across updates.
            self._true_positives = torch.cat([self._true_positives, true_positives], dim=0)  # type: torch.Tensor
            self._positives = torch.cat([self._positives, all_positives], dim=0)  # type: torch.Tensor
        else:
            # Accumulate the sum of per-sample precisions and how many were
            # summed; self.eps keeps the denominator non-zero (assuming eps > 0)
            # when nothing was predicted positive for a sample.
            self._true_positives += torch.sum(true_positives / (all_positives + self.eps))
            self._positives += len(all_positives)
    else:
        self._true_positives += true_positives
        self._positives += all_positives

    self._updated = True