def run_test()

in leaderboard/stats.py [0:0]


    def run_test(self, test_type: StatTest):
        """
        Run one paired statistical test over every unordered pair of models.

        Supported values of ``test_type`` and the callable each one selects:

        * ``StatTest.WILCOXON`` -> ``stats.wilcoxon``. The Wilcoxon test is a
          non-parametric test for comparing matched samples. It is a
          non-parametric alternative to the Student T-test, which assumes
          that the difference is normally distributed. It determines if two
          dependent samples are from the same distribution by comparing
          population mean rank differences.
        * ``StatTest.STUDENT_T`` -> ``stats.ttest_rel``.
        * ``StatTest.MCNEMAR`` -> ``mcnemar_test``.
        * ``StatTest.SEM`` -> the callable returned by
          ``self.create_standard_error_of_measure()``.
        * ``StatTest.SEE`` -> the callable returned by
          ``self.create_standard_error_of_estimation()``. Unlike the others
          it is called with the two model ids (not the paired arrays) and
          returns ``(model_a_skill, model_b_skill, statistic, pvalue)``; the
          two skills are stored in the result's ``metadata``.

        Args:
            test_type: Which test to run.

        Returns:
            A list with one ``PairedStats`` per unordered model pair. When
            the paired data of the two models has no differing element the
            test is not run and ``statistic``/``pvalue`` are ``None``.

        Raises:
            ValueError: If ``test_type`` is not one of the supported tests.
        """
        if test_type == StatTest.WILCOXON:
            stat_test = stats.wilcoxon
        elif test_type == StatTest.STUDENT_T:
            stat_test = stats.ttest_rel
        elif test_type == StatTest.MCNEMAR:
            stat_test = mcnemar_test
        elif test_type == StatTest.SEM:
            stat_test = self.create_standard_error_of_measure()
        elif test_type == StatTest.SEE:
            stat_test = self.create_standard_error_of_estimation()
        else:
            raise ValueError(f"Invalid test: {test_type}")

        results = []
        model_ids = list(self.data.scored_predictions.keys())
        # Every unordered pair exactly once, never a model against itself.
        # NOTE(review): relies on scored_predictions keys being unique
        # (true for a dict) -- confirm if the container type ever changes.
        # Materialized as a list so tqdm can report an accurate total.
        model_pairs = list(itertools.combinations(model_ids, 2))
        if self.parallel:
            # Progress-bar row is this test's index in TESTS.
            tqdm_position = TESTS.index(test_type.value)
        else:
            tqdm_position = None
        for model_a, model_b in tqdm.tqdm(
            model_pairs, position=tqdm_position, desc=f"Test: {test_type.value}"
        ):
            model_a_array, model_b_array = self.extract_paired_data(model_a, model_b)
            model_a_array = np.array(model_a_array)
            model_b_array = np.array(model_b_array)

            model_a_score = self.compute_model_scores(model_a)
            model_b_score = self.compute_model_scores(model_b)

            # Fields that are the same whether or not the test is run.
            pair_fields = {
                "model_a": model_a,
                "model_b": model_b,
                "key": " ".join(sorted([model_a, model_b])),
                "score_a": model_a_score,
                "score_b": model_b_score,
                "max_score": max(model_a_score, model_b_score),
                "min_score": min(model_a_score, model_b_score),
                "diff": abs(model_a_score - model_b_score),
                "test": test_type.value,
                "metric": self._metric,
                "fold": "dev",
            }

            # Skip the test only when the paired samples are identical
            # element by element. Checking that the *sum* of differences is
            # zero would also skip pairs whose differences merely cancel out
            # (e.g. [1, 0] vs [0, 1]), which are perfectly testable.
            if not np.any(model_a_array - model_b_array):
                results.append(PairedStats(statistic=None, pvalue=None, **pair_fields))
                continue

            if test_type == StatTest.SEE:
                # SEE works from the model ids rather than the paired arrays.
                model_a_skill, model_b_skill, statistic, pvalue = stat_test(model_a, model_b)
                metadata = {
                    "model_a_skill": model_a_skill,
                    "model_b_skill": model_b_skill,
                }
            else:
                statistic, pvalue = stat_test(model_a_array, model_b_array)
                metadata = None
            results.append(
                PairedStats(
                    statistic=statistic,
                    pvalue=pvalue,
                    metadata=metadata,
                    **pair_fields,
                )
            )
        return results