in src/helpers.py [0:0]
def roc_auc_score(self) -> List[float]:
    """Compute the ROC-AUC score for each task.

    Concatenates the accumulated per-batch predictions (``self.y_pred``) and
    labels (``self.y_true``) along dim 0, then scores every column (task)
    independently. A 1-D result is treated as a single task.

    Returns:
        One ROC-AUC score per task, in column order.

    Raises:
        Whatever the module-level ``roc_auc_score`` raises for a task it
        cannot score. Presumably this is ``sklearn.metrics.roc_auc_score``,
        which raises ``ValueError`` when a task's labels contain only one
        class. TODO confirm against the file's imports.
    """
    # NOTE: this method shares its name with the module-level scoring
    # function. Inside the body, the bare name ``roc_auc_score`` resolves to
    # the module global (class attributes are not in a method's scope), not
    # to this method.
    # detach()/cpu() so that .numpy() also works for tensors that require
    # grad or live on an accelerator.
    y_pred = torch.cat(self.y_pred, dim=0).detach().cpu()
    y_true = torch.cat(self.y_true, dim=0).detach().cpu()

    # Accept single-task 1-D tensors by viewing them as one column.
    if y_true.dim() == 1:
        y_true = y_true.unsqueeze(1)
    if y_pred.dim() == 1:
        y_pred = y_pred.unsqueeze(1)

    # This assumes the binary case only. ROC-AUC depends only on the ranking
    # of the scores, and sigmoid is monotonic, so raw logits yield the same
    # score. The exception is saturation: in float32, sigmoid returns exactly
    # 1.0 for logits above roughly 17, which creates spurious ties that
    # distort the AUC. Ranking the raw values avoids that. Upcasting to
    # float64 is lossless and lets .numpy() handle half/bfloat16 inputs.
    y_pred = y_pred.double()

    scores = []
    for task in range(y_true.shape[1]):
        task_y_true = y_true[:, task].numpy()
        task_y_pred = y_pred[:, task].numpy()
        scores.append(roc_auc_score(task_y_true, task_y_pred))
    return scores