tzrec/metrics/grouped_auc.py (105 lines of code) (raw):
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, List, Tuple

import torch
from torch import distributed as dist
from torchmetrics import Metric
from torchmetrics.functional.classification.auroc import _binary_auroc_compute
def custom_reduce_fx(
    data_list: List[Dict[str, List[torch.Tensor]]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Custom reduce func for distributed training.

    Re-partitions the evaluation data so that the current rank keeps only the
    samples whose ``grouping_key % world_size`` equals its rank. All samples
    of one group therefore end up on the same rank.

    Args:
        data_list (list): list of dicts with the keys ``"preds"``,
            ``"target"`` and ``"grouping_key"``. Each value is a sequence of
            1d-tensors indexed by ``0 .. world_size - 1`` (presumably one
            tensor gathered from each rank -- confirm against the sync
            mechanism that calls this function).

    Returns:
        tuple: ``(preds, target, grouping_key)``, three 1d-tensors holding the
            samples assigned to the current rank.
    """
    # The launcher-provided environment variables take precedence. When they
    # are absent but a process group is initialized, ask torch.distributed
    # instead of silently behaving like a single-process run.
    dist_ready = dist.is_available() and dist.is_initialized()
    env_world_size = os.environ.get("WORLD_SIZE")
    env_rank = os.environ.get("RANK")

    if env_world_size is not None:
        world_size = int(env_world_size)
    elif dist_ready:
        world_size = dist.get_world_size()
    else:
        world_size = 1

    # NOTE: this is the global rank (RANK), not the per-node LOCAL_RANK.
    if env_rank is not None:
        rank = int(env_rank)
    elif dist_ready:
        rank = dist.get_rank()
    else:
        rank = 0

    pred_reduce = []
    target_reduce = []
    key_reduce = []
    for data in data_list:
        for i in range(world_size):
            keys = data["grouping_key"][i]
            # True for the samples whose group is owned by this rank.
            owned = keys % world_size == rank
            pred_reduce.append(torch.masked_select(data["preds"][i], owned))
            target_reduce.append(torch.masked_select(data["target"][i], owned))
            key_reduce.append(torch.masked_select(keys, owned))
    return torch.cat(pred_reduce), torch.cat(target_reduce), torch.cat(key_reduce)
class GroupedAUC(Metric):
    """Grouped AUC.

    Samples are bucketed by ``grouping_key``. A binary AUC is computed for
    every group that contains both positive and negative targets, and the
    metric is the unweighted mean of those per-group AUCs.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # One dict per update() call. When the state is reduced with
        # custom_reduce_fx it becomes a (preds, target, grouping_key) tuple.
        self.add_state("eval_data", default=[], dist_reduce_fx=custom_reduce_fx)

    # pyre-ignore [14]
    def update(
        self, preds: torch.Tensor, target: torch.Tensor, grouping_key: torch.Tensor
    ) -> None:
        """Update the metric.

        Args:
            preds (Tensor): a float 1d-tensor of predictions.
            target (Tensor): a integer 1d-tensor of target.
            grouping_key (Tensor): a integer 1d-tensor with group id.
        """
        self.eval_data.append(
            {"preds": preds, "target": target, "grouping_key": grouping_key}
        )

    @staticmethod
    def _gather_and_sum(value: torch.Tensor) -> torch.Tensor:
        """All-gather ``value`` from every rank and return the sum as a scalar."""
        # NCCL collectives need CUDA tensors; other backends use the tensor as is.
        if dist.get_backend() == "nccl":
            value = value.cuda()
        gathered = [torch.empty_like(value) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, value)
        return torch.sum(torch.stack(gathered))

    def compute(self) -> torch.Tensor:
        """Compute the metric.

        Returns:
            Tensor: a 0-dim tensor with the mean of the per-group AUCs. The
                value is NaN when no group contains both classes.

        Raises:
            RuntimeError: if the metric holds no data.
        """
        state = self.eval_data
        if len(state) == 0:  # pyre-ignore [6]
            raise RuntimeError(
                "GroupedAUC.compute() was called without any accumulated data."
            )

        # Decide the layout from the state itself rather than from whether a
        # process group exists: raw state is a list of per-update dicts, reduced
        # state is the (preds, target, grouping_key) output of custom_reduce_fx.
        if isinstance(state[0], dict):  # pyre-ignore [29]
            preds = torch.cat([data["preds"] for data in state])  # pyre-ignore [29]
            target = torch.cat([data["target"] for data in state])  # pyre-ignore [29]
            grouping_key = torch.cat(
                [data["grouping_key"] for data in state]  # pyre-ignore [29]
            )
        else:
            preds, target, grouping_key = (
                state[0],  # pyre-ignore [29]
                state[1],  # pyre-ignore [29]
                state[2],  # pyre-ignore [29]
            )

        # Sort by key so every group is a contiguous run, then split on run lengths.
        sorted_grouping_key, indices = torch.sort(grouping_key)
        _, counts = torch.unique_consecutive(sorted_grouping_key, return_counts=True)
        group_sizes = counts.tolist()
        grouped_preds = torch.split(preds[indices], group_sizes)
        grouped_target = torch.split(target[indices], group_sizes)

        aucs = []
        for group_preds, group_target in zip(grouped_preds, grouped_target):
            # Skip groups with a single class: their mean target is exactly 0 or 1.
            mean_target = torch.mean(group_target.to(torch.float32)).item()
            if 0 < mean_target < 1:
                aucs.append(_binary_auroc_compute((group_preds, group_target), None))

        sum_gauc = torch.sum(torch.Tensor(aucs))
        group_cnt = len(aucs)

        # gather metric data across processes
        if (
            dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        ):
            total_sum_gauc = self._gather_and_sum(sum_gauc)
            total_group_cnt = self._gather_and_sum(torch.Tensor([group_cnt]))
            return total_sum_gauc / total_group_cnt
        return sum_gauc / group_cnt