in tzrec/metrics/grouped_auc.py [0:0]
def compute(self) -> torch.Tensor:
    """Compute the grouped AUC (GAUC).

    Samples are bucketed by ``grouping_key``. A binary AUROC is computed for
    every group that contains both classes (0 < mean target < 1); groups whose
    targets are all-negative or all-positive are skipped. The metric is the
    unweighted mean of the per-group AUCs. When ``torch.distributed`` is
    initialized with more than one rank, the per-rank AUC sums and scored-group
    counts are summed across ranks before dividing.

    NOTE(review): the distributed path counts a group once per rank that holds
    it, so it presumably relies on ``update`` routing every grouping key to a
    single rank -- confirm against ``update``.

    Returns:
        A 0-d float32 tensor holding the mean per-group AUC, or NaN if no group
        contains both classes. It is on the current CUDA device when reduced
        over the ``nccl`` backend and on CPU otherwise.
    """
    if not dist.is_initialized():
        # Locally accumulated batches: a list of per-batch dicts.
        preds = torch.cat(
            [data["preds"] for data in self.eval_data]  # pyre-ignore [29]
        )
        target = torch.cat(
            [data["target"] for data in self.eval_data]  # pyre-ignore [29]
        )
        grouping_key = torch.cat(
            [data["grouping_key"] for data in self.eval_data]  # pyre-ignore [29]
        )
    else:
        # NOTE(review): eval_data is indexed positionally as
        # (preds, target, grouping_key) here; presumably these are already
        # concatenated tensors -- confirm against ``update``.
        preds, target, grouping_key = (
            self.eval_data[0],  # pyre-ignore [29]
            self.eval_data[1],  # pyre-ignore [29]
            self.eval_data[2],  # pyre-ignore [29]
        )

    # Sort so that every group occupies one contiguous run of samples.
    sorted_grouping_key, indices = torch.sort(grouping_key)
    sorted_preds = preds[indices]
    sorted_target = target[indices]
    _, group_ids, counts = torch.unique_consecutive(
        sorted_grouping_key, return_inverse=True, return_counts=True
    )

    # Sum targets per group in one vectorized pass; a group has both classes
    # iff 0 < sum < size (equivalently 0 < mean < 1). This needs a single
    # host sync instead of one ``.item()`` call per group.
    target_sums = torch.zeros(
        counts.numel(), dtype=torch.float64, device=counts.device
    ).index_add_(
        0, group_ids, sorted_target.to(device=counts.device, dtype=torch.float64)
    )
    has_both_classes = ((target_sums > 0) & (target_sums < counts)).tolist()

    sizes = counts.tolist()
    aucs = [
        _binary_auroc_compute((group_preds, group_target), None)
        for group_preds, group_target, keep in zip(
            torch.split(sorted_preds, sizes),
            torch.split(sorted_target, sizes),
            has_both_classes,
        )
        if keep
    ]

    # stats = [sum of per-group AUCs, number of scored groups], kept in float64
    # so large group counts and AUC sums do not lose precision.
    stats = torch.zeros(2, dtype=torch.float64)
    if aucs:
        stats[0] = torch.stack(aucs).to(device="cpu", dtype=torch.float64).sum()
    stats[1] = len(aucs)

    # Combine across ranks with a single all_reduce instead of two all_gathers.
    if dist.is_initialized() and dist.get_world_size() > 1:
        if dist.get_backend() == "nccl":
            stats = stats.cuda()  # NCCL collectives operate on CUDA tensors
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)

    # 0 / 0 yields NaN when no group had both classes (same as before).
    return (stats[0] / stats[1]).to(torch.float32)