def compute()

in tzrec/metrics/grouped_auc.py [0:0]


    def compute(self) -> torch.Tensor:
        """Compute the grouped AUC, i.e. the mean of per-group AUC values.

        Samples are bucketed by ``grouping_key``. A group contributes an AUC
        only when the mean of its targets lies strictly between 0 and 1
        (for binary targets: the group holds both classes). When more than
        one process participates, every rank all-gathers its local AUC sum
        and its local group count, and the global mean is derived from those.

        Returns:
            A scalar tensor: sum of per-group AUCs divided by the number of
            contributing groups. When no group contributes, the division is
            0 / 0 and the result is NaN.
        """
        distributed = dist.is_initialized()
        if distributed:
            # State is read positionally here: (preds, target, grouping_key).
            # NOTE(review): presumably already concatenated by the metric
            # state sync -- confirm against the class definition.
            all_preds = self.eval_data[0]  # pyre-ignore [29]
            all_target = self.eval_data[1]  # pyre-ignore [29]
            all_keys = self.eval_data[2]  # pyre-ignore [29]
        else:
            # State is a sequence of per-update dicts; flatten each field.
            all_preds = torch.cat(
                [batch["preds"] for batch in self.eval_data]  # pyre-ignore [29]
            )
            all_target = torch.cat(
                [batch["target"] for batch in self.eval_data]  # pyre-ignore [29]
            )
            all_keys = torch.cat(
                [batch["grouping_key"] for batch in self.eval_data]  # pyre-ignore [29]
            )

        # Sorting by key makes every group a contiguous run, so the run
        # lengths can be used directly as split sizes.
        sorted_keys, order = all_keys.sort()
        group_sizes = torch.unique_consecutive(sorted_keys, return_counts=True)[
            1
        ].tolist()
        preds_by_group = all_preds[order].split(group_sizes)
        target_by_group = all_target[order].split(group_sizes)

        aucs = []
        for group_preds, group_target in zip(preds_by_group, target_by_group):
            target_mean = group_target.to(torch.float32).mean().item()
            if not (0 < target_mean < 1):
                # Skip groups whose target mean is not strictly inside (0, 1).
                continue
            aucs.append(_binary_auroc_compute((group_preds, group_target), None))
        local_sum = torch.Tensor(aucs).sum()

        if not (distributed and dist.get_world_size() > 1):
            return local_sum / len(aucs)

        # gather metric data across processes
        world_size = dist.get_world_size()
        local_cnt = torch.Tensor([len(aucs)])
        if dist.get_backend() == "nccl":
            # The tensors handed to all_gather are moved to CUDA for NCCL.
            local_sum = local_sum.cuda()
            local_cnt = local_cnt.cuda()

        gathered_sums = [torch.empty_like(local_sum) for _ in range(world_size)]
        gathered_cnts = [torch.empty_like(local_cnt) for _ in range(world_size)]
        dist.all_gather(gathered_sums, local_sum)
        dist.all_gather(gathered_cnts, local_cnt)

        return torch.stack(gathered_sums).sum() / torch.stack(gathered_cnts).sum()