def roc_auc_score()

in src/helpers.py [0:0]


    def roc_auc_score(self) -> List[float]:
        """Compute the ROC-AUC score for each task.

        Concatenates the accumulated batches in ``self.y_pred`` and
        ``self.y_true`` along dim 0. Each column is then scored as an
        independent binary task. A 1-D tensor is treated as a single task.

        Returns:
            One ROC-AUC score per task, in column order.

        Raises:
            RuntimeError: if ``self.y_pred`` or ``self.y_true`` is empty
                (``torch.cat`` rejects an empty list).
        """
        y_pred = torch.cat(self.y_pred, dim=0)
        y_true = torch.cat(self.y_true, dim=0)
        # Single-task runs may accumulate 1-D tensors; view them as one
        # column so the per-task indexing below works for both layouts.
        if y_pred.dim() == 1:
            y_pred = y_pred.unsqueeze(1)
        if y_true.dim() == 1:
            y_true = y_true.unsqueeze(1)
        # Tensor.numpy() refuses tensors that require grad or that live on an
        # accelerator, so detach and move to CPU first (no-ops otherwise).
        # Convert once here rather than once per task column.
        y_pred_np = y_pred.detach().cpu().numpy()
        y_true_np = y_true.detach().cpu().numpy()
        # No sigmoid is applied. ROC-AUC depends only on the ranking of the
        # scores, which a monotone sigmoid preserves. In float32, however,
        # sigmoid rounds large logits to exactly 1.0, turning distinct scores
        # into ties that change the AUC. Raw logits (or probabilities) rank
        # identically without that loss.
        # The bare `roc_auc_score` below is the module-level scorer
        # (presumably sklearn.metrics.roc_auc_score), not this method: names
        # inside a method body do not resolve to class attributes. If it is
        # sklearn's, it raises ValueError for a task whose labels contain a
        # single class.
        return [
            roc_auc_score(y_true_np[:, task], y_pred_np[:, task])
            for task in range(y_true_np.shape[1])
        ]