def update()

in torchrecipes/vision/image_classification/metrics/multilabel_accuracy.py [0:0]


    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Updates the state with predictions and target.

        Each sample contributes to ``correct`` the largest ``target`` value
        found among its ``top_k`` highest-scoring predicted classes. For 0/1
        one-hot / multi-label targets this is 1 when at least one of the
        top-k predictions is a positive label and 0 otherwise. ``total`` is
        increased by the batch size.

        Args:
            preds: tensor of shape (B, C) where each value is either logit or
                class probability.
            target: tensor of shape (B, C), which is one-hot / multi-label
                encoded.

        Raises:
            AssertionError: if ``preds`` and ``target`` differ in shape, or if
                ``top_k`` exceeds the number of classes ``C``. (These are
                ``assert`` checks, so they are skipped under ``python -O``.)
        """
        assert preds.shape == target.shape, (
            "predictions and target must be of the same shape. "
            f"Got preds({preds.shape}) vs target({target.shape})."
        )
        num_classes = target.shape[1]
        assert (
            num_classes >= self._top_k
        ), f"top-k({self._top_k}) is greater than the number of classes({num_classes})"
        preds, target = self._format_inputs(preds, target)

        # topk(k) already yields exactly k indices per row, so top_idx needs
        # no further slicing before it is used with gather.
        # pyre-ignore[16]: torch.Tensor has attribute topk
        _, top_idx = preds.topk(self._top_k, dim=1, largest=True, sorted=True)

        # Read the target value at each of the top-k predicted classes and
        # keep the per-sample maximum; for 0/1 targets this answers "did any
        # top-k prediction land on a positive label".
        hits = torch.gather(target, dim=1, index=top_idx).max(dim=1).values

        # pyre-ignore[16]: Accuracy has attribute correct
        self.correct += hits.sum().item()
        # pyre-ignore[16]: Accuracy has attribute total
        self.total += preds.shape[0]