in torchrecipes/vision/image_classification/metrics/multilabel_accuracy.py [0:0]
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
    """Accumulate top-k multi-label hits for a single batch.

    For every sample, the ``self._top_k`` highest-scoring classes in ``preds``
    are located, the ``target`` values at those classes are looked up, and the
    largest of them is added to the running ``correct`` count. For 0/1
    targets this adds 1 exactly when at least one of the top-k classes is a
    true label. The batch size is then added to ``total``.

    Args:
        preds: tensor of shape (B, C) holding, per class, either a logit or a
            class probability.
        target: tensor of shape (B, C) with one-hot / multi-label encoding.

    Raises:
        AssertionError: if ``preds`` and ``target`` differ in shape, or if
            ``self._top_k`` exceeds the number of classes (these checks are
            plain ``assert`` statements and are skipped under ``python -O``).
    """
    same_shape = preds.shape == target.shape
    assert same_shape, (
        "predictions and target must be of the same shape. "
        f"Got preds({preds.shape}) vs target({target.shape})."
    )

    k = self._top_k
    n_classes = target.shape[1]
    assert k <= n_classes, (
        f"top-k({k}) is greater than the number of classes({n_classes})"
    )

    # Validation runs on the raw inputs; everything below uses the formatted ones.
    scores, labels = self._format_inputs(preds, target)

    # pyre-ignore[16]: torch.Tensor has attribute topk
    _, best_classes = scores.topk(k, dim=1, largest=True, sorted=True)

    # Target value at each of the k best-scoring classes, then the row-wise max.
    labels_at_best = torch.gather(labels, dim=1, index=best_classes)
    per_sample_hit = labels_at_best.max(dim=1).values
    batch_hits = per_sample_hit.sum().item()

    # pyre-ignore[16]: Accuracy has attribute correct
    self.correct += batch_hits
    # pyre-ignore[16]: Accuracy has attribute total
    self.total += scores.shape[0]