leaderboard/cat_sampling_stability.py [112:186]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        )
        CACHE_IRT[irt_type] = irt_model
        return irt_model


class InfoSampler(Sampler):
    def __init__(
        self, *, fold: str, irt_type: str, subject_ids: List[str],
    ):
        """Create a sampler that selects items by IRT item information.

        Args:
            fold: Forwarded unchanged to ``Sampler.__init__``.
            irt_type: IRT model variant; only ``"2PL"`` is supported.
            subject_ids: Subject identifiers kept on the instance
                (presumably the subjects whose skills drive selection in
                ``sample`` -- confirm against that method).

        Raises:
            ValueError: If ``irt_type`` is anything other than ``"2PL"``.
        """
        super().__init__(fold)
        if irt_type != "2PL":
            raise ValueError("Not impl for more than 2PL")

        self._irt_type = irt_type
        self._subject_ids = subject_ids

        # Selection state: nothing is chosen yet, so every item id known to the
        # base sampler (``self._item_ids``, presumably set by ``Sampler``) is a
        # remaining candidate.
        self._chosen_items = set()
        self._remaining_items = set(self._item_ids)

    def estimate_theta(self, subject_id: str) -> float:
        """Estimate a subject's skill (theta) by maximum likelihood.

        The subject's ``exact_match`` responses on the currently chosen items
        are scored under the 2PL model, and the negative log-likelihood is
        minimized with L-BFGS-B starting from theta = 0.

        Args:
            subject_id: Key into the leaderboard's scored predictions.

        Returns:
            The theta value found by the optimizer.
        """
        # NOTE(review): predictions are always read from the "dev" fold,
        # independent of the fold this sampler was built with -- confirm intended.
        responses = cached_leaderboard_predictions("dev").scored_predictions[subject_id][
            "exact_match"
        ]
        irt_model = cached_irt(self._irt_type)

        # Resolve item parameters and observed outcomes once: the optimizer
        # evaluates the objective many times and only theta changes between calls.
        observations = []
        for item_id in self._chosen_items:
            item_stats = irt_model.example_stats[item_id]
            observations.append(
                (item_stats.diff, item_stats.disc, responses[item_id] == 1)
            )

        def neg_log_likelihood(theta):
            # With no chosen items this objective is identically zero.
            total = 0
            for diff, disc, answered_correctly in observations:
                prob = prob_2pl(skill=theta, diff=diff, disc=disc)
                # A response equal to 1 counts as correct; anything else as incorrect.
                total += np.log(prob) if answered_correctly else np.log(1 - prob)
            return -total

        return optimize.minimize(neg_log_likelihood, 0, method="L-BFGS-B").x[0]

    def compute_item_information(self, subject_skill: float, item_id: str) -> float:
        """Return the 2PL item information of ``item_id`` at ``subject_skill``.

        Computed as ``disc ** 2 * p * (1 - p)``, where ``p`` is the 2PL
        probability of a correct response at that skill.
        """
        stats = cached_irt(self._irt_type).example_stats[item_id]
        p_correct = prob_2pl(skill=subject_skill, disc=stats.disc, diff=stats.diff)
        p_wrong = 1 - p_correct
        return stats.disc ** 2 * p_correct * p_wrong

    def compute_sum_information(self, subject_skills: Dict[str, float]) -> Dict[str, float]:
        """Sum each remaining item's information over every subject skill.

        Args:
            subject_skills: Mapping whose values are skill estimates; only the
                values are used.

        Returns:
            A ``defaultdict(float)`` keyed by remaining item id. It is empty
            when ``subject_skills`` is empty.
        """
        totals = defaultdict(float)
        for skill in subject_skills.values():
            # Information contributed by this one skill to every candidate item.
            contributions = {
                item_id: self.compute_item_information(skill, item_id)
                for item_id in self._remaining_items
            }
            for item_id, info in contributions.items():
                totals[item_id] += info
        return totals

    def initial_items(self, n_items: int):
        """Return the ``example_id`` of the ``n_items`` most discriminating items.

        Items are ranked by ``disc`` in descending order. The sort is stable,
        so ties keep their order from ``self._item_ids``.
        """
        stats_by_item = {
            item_id: cached_irt(self._irt_type).example_stats[item_id]
            for item_id in self._item_ids
        }
        ranked = sorted(stats_by_item.values(), key=lambda stats: stats.disc, reverse=True)
        top = ranked[:n_items]
        return [stats.example_id for stats in top]

    def sample(self, n_items: int):
        if n_items < len(self._chosen_items):
            raise ValueError("cannot choose fewer items than last queried")
        elif n_items == len(self._chosen_items):
            return list(self._chosen_items)
        elif len(self._chosen_items) == 0:
            selected_items = self.initial_items(n_items)
            self._remaining_items = self._remaining_items - set(selected_items)
            self._chosen_items = self._chosen_items | set(selected_items)
            return selected_items
        else:
            subject_skills = {}
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



leaderboard/sampling_stability.py [110:184]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        )
        CACHE_IRT[irt_type] = irt_model
        return irt_model


class InfoSampler(Sampler):
    def __init__(
        self, *, fold: str, irt_type: str, subject_ids: List[str],
    ):
        """Create a sampler that selects items by IRT item information.

        Args:
            fold: Forwarded unchanged to ``Sampler.__init__``.
            irt_type: IRT model variant; only ``"2PL"`` is supported.
            subject_ids: Subject identifiers kept on the instance
                (presumably the subjects whose skills drive selection in
                ``sample`` -- confirm against that method).

        Raises:
            ValueError: If ``irt_type`` is anything other than ``"2PL"``.
        """
        super().__init__(fold)
        if irt_type != "2PL":
            raise ValueError("Not impl for more than 2PL")

        self._irt_type = irt_type
        self._subject_ids = subject_ids

        # Selection state: nothing is chosen yet, so every item id known to the
        # base sampler (``self._item_ids``, presumably set by ``Sampler``) is a
        # remaining candidate.
        self._chosen_items = set()
        self._remaining_items = set(self._item_ids)

    def estimate_theta(self, subject_id: str) -> float:
        """Estimate a subject's skill (theta) by maximum likelihood.

        The subject's ``exact_match`` responses on the currently chosen items
        are scored under the 2PL model, and the negative log-likelihood is
        minimized with L-BFGS-B starting from theta = 0.

        Args:
            subject_id: Key into the leaderboard's scored predictions.

        Returns:
            The theta value found by the optimizer.
        """
        # NOTE(review): predictions are always read from the "dev" fold,
        # independent of the fold this sampler was built with -- confirm intended.
        responses = cached_leaderboard_predictions("dev").scored_predictions[subject_id][
            "exact_match"
        ]
        irt_model = cached_irt(self._irt_type)

        # Resolve item parameters and observed outcomes once: the optimizer
        # evaluates the objective many times and only theta changes between calls.
        observations = []
        for item_id in self._chosen_items:
            item_stats = irt_model.example_stats[item_id]
            observations.append(
                (item_stats.diff, item_stats.disc, responses[item_id] == 1)
            )

        def neg_log_likelihood(theta):
            # With no chosen items this objective is identically zero.
            total = 0
            for diff, disc, answered_correctly in observations:
                prob = prob_2pl(skill=theta, diff=diff, disc=disc)
                # A response equal to 1 counts as correct; anything else as incorrect.
                total += np.log(prob) if answered_correctly else np.log(1 - prob)
            return -total

        return optimize.minimize(neg_log_likelihood, 0, method="L-BFGS-B").x[0]

    def compute_item_information(self, subject_skill: float, item_id: str) -> float:
        """Return the 2PL item information of ``item_id`` at ``subject_skill``.

        Computed as ``disc ** 2 * p * (1 - p)``, where ``p`` is the 2PL
        probability of a correct response at that skill.
        """
        stats = cached_irt(self._irt_type).example_stats[item_id]
        p_correct = prob_2pl(skill=subject_skill, disc=stats.disc, diff=stats.diff)
        p_wrong = 1 - p_correct
        return stats.disc ** 2 * p_correct * p_wrong

    def compute_sum_information(self, subject_skills: Dict[str, float]) -> Dict[str, float]:
        """Sum each remaining item's information over every subject skill.

        Args:
            subject_skills: Mapping whose values are skill estimates; only the
                values are used.

        Returns:
            A ``defaultdict(float)`` keyed by remaining item id. It is empty
            when ``subject_skills`` is empty.
        """
        totals = defaultdict(float)
        for skill in subject_skills.values():
            # Information contributed by this one skill to every candidate item.
            contributions = {
                item_id: self.compute_item_information(skill, item_id)
                for item_id in self._remaining_items
            }
            for item_id, info in contributions.items():
                totals[item_id] += info
        return totals

    def initial_items(self, n_items: int):
        """Return the ``example_id`` of the ``n_items`` most discriminating items.

        Items are ranked by ``disc`` in descending order. The sort is stable,
        so ties keep their order from ``self._item_ids``.
        """
        stats_by_item = {
            item_id: cached_irt(self._irt_type).example_stats[item_id]
            for item_id in self._item_ids
        }
        ranked = sorted(stats_by_item.values(), key=lambda stats: stats.disc, reverse=True)
        top = ranked[:n_items]
        return [stats.example_id for stats in top]

    def sample(self, n_items: int):
        if n_items < len(self._chosen_items):
            raise ValueError("cannot choose fewer items than last queried")
        elif n_items == len(self._chosen_items):
            return list(self._chosen_items)
        elif len(self._chosen_items) == 0:
            selected_items = self.initial_items(n_items)
            self._remaining_items = self._remaining_items - set(selected_items)
            self._chosen_items = self._chosen_items | set(selected_items)
            return selected_items
        else:
            subject_skills = {}
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



