in scripts/plotting/plot_sweep.py [0:0]
def load_metrics(nconfigs, split_by_job_id: bool = False) -> Metrics:
    """Retrieve metrics into a clustered format (given clustered runs).

    Logs are fetched concurrently by calling ``get_all_logs`` once per key of
    ``nconfigs`` on a thread pool, then flattened into a single mapping.

    Args:
        nconfigs: Mapping whose keys are experiment IDs. It is also forwarded
            unchanged as the first argument of ``get_all_logs``. Keys may be
            the special string ``"ALL"``, another string, or a tuple.
        split_by_job_id: If True, emit one entry per job. The key is the
            experiment ID extended with a ``"job_id=<id>"`` component: the
            component alone for ``"ALL"``, comma-joined for string IDs, and
            appended for tuple IDs. If False, each experiment's logs must
            contain exactly one entry, under job key ``-1``, which is used
            as-is.

    Returns:
        A dict mapping (possibly extended) experiment IDs to their logs.
        It is empty when ``nconfigs`` is empty.
    """
    if not nconfigs:
        # Nothing to load; also avoids ThreadPoolExecutor(max_workers=0),
        # which raises ValueError.
        return {}

    pget = functools.partial(get_all_logs, nconfigs, split_by_job_id=split_by_job_id)
    # The context manager guarantees the worker threads are shut down, even
    # if one of the loaders raises. Results are materialized inside the block
    # so every loader has finished before post-processing starts.
    with ThreadPoolExecutor(min(MAX_XP_LOADERS, len(nconfigs))) as pool:
        all_xp_logs = list(pool.map(pget, list(nconfigs.keys())))

    metrics = {}
    # Each loader result is unpacked as (experiment_id, {job_id: logs}).
    for experiment_id, xp_logs in all_xp_logs:
        if split_by_job_id:
            # Combine the job ID with `experiment_id` to obtain the final ID.
            for job_id, logs in xp_logs.items():
                job_key = f"job_id={job_id}"
                if experiment_id == "ALL":
                    # The "ALL" pseudo-experiment contributes no prefix.
                    new_id = job_key
                elif isinstance(experiment_id, str):
                    new_id = ",".join([experiment_id, job_key])
                else:
                    assert isinstance(experiment_id, tuple), (
                        f"unexpected experiment id type: {type(experiment_id)!r}"
                    )
                    new_id = experiment_id + (job_key,)
                metrics[new_id] = logs
        else:
            # Without splitting, each experiment's logs are expected to be
            # aggregated under the single job key -1.
            assert len(xp_logs) == 1 and next(iter(xp_logs)) == -1, (
                f"expected logs keyed only by -1 for {experiment_id!r}, "
                f"got keys {list(xp_logs)!r}"
            )
            metrics[experiment_id] = xp_logs[-1]
    return metrics