def load_irt()

in leaderboard/www/app.py [0:0]


def _irt_parameters_path(fold: str, backend: str, model: str):
    """Return the ``parameters.json`` path of one SQuAD IRT fit.

    The directory is ``conf["irt"]["squad"][fold][backend][model]["full"]``,
    resolved relative to ``DATA_ROOT``.
    """
    return DATA_ROOT / conf["irt"]["squad"][fold][backend][model]["full"] / "parameters.json"


def _irt_zero_based_index(str_idx: str, index_base: int, n_items: int) -> int:
    """Convert a serialized parameter index into a validated zero-based index.

    Args:
        str_idx: index key as stored in the fit's id mapping (a string).
        index_base: 1 for one-based output (Stan), 0 for zero-based (Pyro).
        n_items: number of ids in the mapping; valid results are ``0..n_items-1``.

    Raises:
        ValueError: if the converted index is outside ``[0, n_items)``.
    """
    idx = int(str_idx) - index_base
    # Reject negatives explicitly: Python would otherwise silently wrap them
    # around to the end of the parameter list.
    if idx < 0 or idx > n_items - 1:
        raise ValueError(f"Invalid index: {idx}")
    return idx


def _irt_example_stats(results, irt_model, index_base: int) -> dict:
    """Build ``{example_id: ExampleStats}`` from a 2PL fit's ``diff``/``disc`` arrays."""
    n_examples = len(results.example_ids)
    stats = {}
    for str_idx, example_id in results.example_ids.items():
        idx = _irt_zero_based_index(str_idx, index_base, n_examples)
        stats[example_id] = ExampleStats(
            irt_model=irt_model.value,
            example_id=example_id,
            diff=results.diff[idx],
            disc=results.disc[idx],
        )
    return stats


def _irt_model_stats(results, irt_model, index_base: int) -> dict:
    """Build ``{model_id: ModelStats}`` from a fit's ``ability`` array."""
    n_models = len(results.model_ids)
    stats = {}
    for str_idx, model_id in results.model_ids.items():
        idx = _irt_zero_based_index(str_idx, index_base, n_models)
        stats[model_id] = ModelStats(
            irt_model=irt_model.value,
            model_id=model_id,
            skill=results.ability[idx],
        )
    return stats


def load_irt(fold: str = "dev") -> None:
    """Load the SQuAD IRT fits for ``fold`` into ``data_cache``.

    Parses the Stan and Pyro 1PL/2PL ``parameters.json`` files configured under
    ``conf["irt"]["squad"][fold]``.  For the two 2PL fits it also builds
    per-example lookup tables (``stan_2pl_map`` / ``pyro_2pl_map``) and
    per-model lookup tables (``stan_2pl_skill`` / ``pyro_2pl_skill``).

    All parsing and index validation happens before ``data_cache`` is touched,
    so if anything raises, ``data_cache`` keeps its previous contents instead
    of ending up half-updated with a mix of folds or partially filled maps.

    Args:
        fold: key under ``conf["irt"]["squad"]`` selecting the data split.

    Raises:
        ValueError: if a parameter index in a 2PL fit falls outside the range
            implied by the number of ids in that fit.
    """
    stan_1pl = IrtResults.parse_file(_irt_parameters_path(fold, "stan", "1PL"))
    pyro_1pl = IrtResults.parse_file(_irt_parameters_path(fold, "pyro", "1PL"))
    stan_2pl = IrtResults.parse_file(_irt_parameters_path(fold, "stan", "2PL"))
    pyro_2pl = IrtResults.parse_file(_irt_parameters_path(fold, "pyro", "2PL"))

    # STAN indexing is one based; the Pyro ids are already zero based.
    stan_2pl_map = _irt_example_stats(stan_2pl, IrtModelType.stan_2pl, index_base=1)
    pyro_2pl_map = _irt_example_stats(pyro_2pl, IrtModelType.pyro_2pl, index_base=0)
    stan_2pl_skill = _irt_model_stats(stan_2pl, IrtModelType.stan_2pl, index_base=1)
    pyro_2pl_skill = _irt_model_stats(pyro_2pl, IrtModelType.pyro_2pl, index_base=0)

    # Publish only once everything above has succeeded, so readers never see
    # an empty or partially built table.
    data_cache.stan_1pl = stan_1pl
    data_cache.pyro_1pl = pyro_1pl
    data_cache.stan_2pl = stan_2pl
    data_cache.pyro_2pl = pyro_2pl
    data_cache.stan_2pl_map = stan_2pl_map
    data_cache.pyro_2pl_map = pyro_2pl_map
    data_cache.stan_2pl_skill = stan_2pl_skill
    data_cache.pyro_2pl_skill = pyro_2pl_skill