in leaderboard/www/app.py [0:0]
# Offsets subtracted from serialized parameter indices to obtain zero-based
# positions: Stan indexing is one based, Pyro's is already zero based.
_STAN_INDEX_OFFSET = 1
_PYRO_INDEX_OFFSET = 0


def _to_zero_based(str_idx: str, offset: int, n_items: int, irt_model: str) -> int:
    """Convert a serialized parameter index into a validated zero-based index.

    Args:
        str_idx: Index key as stored in the parameters file (a string of an int).
        offset: Amount to subtract to make the index zero based.
        n_items: Number of ids in the mapping; valid results are 0..n_items-1.
        irt_model: IRT model label, used only for the error message.

    Returns:
        The zero-based index.

    Raises:
        ValueError: If the converted index falls outside ``[0, n_items - 1]``.
    """
    idx = int(str_idx) - offset
    if idx < 0 or idx > n_items - 1:
        raise ValueError(
            f"Invalid index: {idx} (raw index {str_idx!r}, model {irt_model})"
        )
    return idx


def _example_stats_map(results: "IrtResults", irt_model: str, offset: int) -> dict:
    """Map each example id in *results* to its difficulty/discrimination stats.

    Built into a local dict and returned whole, so a validation failure never
    leaves a partially filled map on the cache.
    """
    n_examples = len(results.example_ids)
    stats = {}
    for str_idx, example_id in results.example_ids.items():
        idx = _to_zero_based(str_idx, offset, n_examples, irt_model)
        stats[example_id] = ExampleStats(
            irt_model=irt_model,
            example_id=example_id,
            diff=results.diff[idx],
            disc=results.disc[idx],
        )
    return stats


def _model_skill_map(results: "IrtResults", irt_model: str, offset: int) -> dict:
    """Map each model id in *results* to its ability (skill) stats.

    Built into a local dict and returned whole, so a validation failure never
    leaves a partially filled map on the cache.
    """
    n_models = len(results.model_ids)
    skills = {}
    for str_idx, model_id in results.model_ids.items():
        idx = _to_zero_based(str_idx, offset, n_models, irt_model)
        skills[model_id] = ModelStats(
            irt_model=irt_model,
            model_id=model_id,
            skill=results.ability[idx],
        )
    return skills


def load_irt(fold: str = "dev"):
    """Load the SQuAD IRT parameters for *fold* into ``data_cache``.

    Parses the Stan and Pyro 1PL/2PL ``parameters.json`` files configured under
    ``conf["irt"]["squad"][fold]`` and, for the 2PL runs, builds per-example
    (difficulty/discrimination) and per-model (skill) lookup dicts keyed by id.

    Args:
        fold: Key of the fold to load from the IRT config.

    Raises:
        KeyError: If *fold* (or one of the runs) is missing from ``conf``.
        ValueError: If a parameters file contains an out-of-range index.
    """
    fold_conf = conf["irt"]["squad"][fold]

    def params_path(backend: str, model_type: str):
        # All runs store their fitted parameters under "<run dir>/parameters.json".
        return DATA_ROOT / fold_conf[backend][model_type]["full"] / "parameters.json"

    data_cache.stan_1pl = IrtResults.parse_file(params_path("stan", "1PL"))
    data_cache.pyro_1pl = IrtResults.parse_file(params_path("pyro", "1PL"))
    data_cache.stan_2pl = IrtResults.parse_file(params_path("stan", "2PL"))
    data_cache.pyro_2pl = IrtResults.parse_file(params_path("pyro", "2PL"))

    data_cache.stan_2pl_map = _example_stats_map(
        data_cache.stan_2pl, IrtModelType.stan_2pl.value, _STAN_INDEX_OFFSET
    )
    data_cache.pyro_2pl_map = _example_stats_map(
        data_cache.pyro_2pl, IrtModelType.pyro_2pl.value, _PYRO_INDEX_OFFSET
    )
    data_cache.stan_2pl_skill = _model_skill_map(
        data_cache.stan_2pl, IrtModelType.stan_2pl.value, _STAN_INDEX_OFFSET
    )
    data_cache.pyro_2pl_skill = _model_skill_map(
        data_cache.pyro_2pl, IrtModelType.pyro_2pl.value, _PYRO_INDEX_OFFSET
    )