def load_metrics()

in scripts/plotting/plot_sweep.py [0:0]


def load_metrics(nconfigs, split_by_job_id: bool = False) -> Metrics:
    """Retrieve metrics into a clustered format (given clustered runs).

    Logs for every key of ``nconfigs`` are loaded concurrently by calling
    ``get_all_logs(nconfigs, key, split_by_job_id=...)`` on a thread pool of
    at most ``MAX_XP_LOADERS`` workers.

    Args:
        nconfigs: Mapping whose keys are experiment IDs. The whole mapping is
            also forwarded to ``get_all_logs``.
        split_by_job_id: If True, emit one entry per job ID. The entry is keyed
            by the experiment ID extended with a ``"job_id=<id>"`` component:
            string IDs are comma-joined, tuple IDs get an extra element, and
            the ``"ALL"`` ID is replaced by the job component alone. If False,
            emit one entry per experiment ID.

    Returns:
        A dict mapping each (possibly job-extended) ID to its logs. The dict is
        empty when ``nconfigs`` is empty.

    Raises:
        AssertionError: If a ``get_all_logs`` result does not have the layout
            this function expects.
    """
    # ThreadPoolExecutor rejects max_workers=0, so an empty config mapping is
    # handled up front rather than crashing on pool creation.
    if not nconfigs:
        return {}

    load_one = functools.partial(
        get_all_logs, nconfigs, split_by_job_id=split_by_job_id
    )
    # Using the executor as a context manager guarantees its worker threads
    # are shut down once every result has been collected (or on error).
    with ThreadPoolExecutor(min(MAX_XP_LOADERS, len(nconfigs))) as pool:
        all_xp_logs = list(pool.map(load_one, list(nconfigs)))

    metrics = {}
    # Each result is unpacked as ``(experiment_id, xp_logs)``, where xp_logs is
    # a mapping keyed by job ID (presumably; verify against get_all_logs).
    for experiment_id, xp_logs in all_xp_logs:
        if split_by_job_id:
            # Combine the job ID with `experiment_id` to obtain the final ID.
            for job_id, logs in xp_logs.items():
                job_key = f"job_id={job_id}"
                if experiment_id == "ALL":
                    new_id = job_key
                elif isinstance(experiment_id, str):
                    new_id = ",".join([experiment_id, job_key])
                else:
                    assert isinstance(experiment_id, tuple), (
                        f"unexpected experiment_id type: {type(experiment_id)!r}"
                    )
                    new_id = experiment_id + (job_key,)
                metrics[new_id] = logs
        else:
            # Without job splitting, all logs are expected under the single
            # sentinel key -1.
            assert len(xp_logs) == 1 and next(iter(xp_logs)) == -1, (
                f"expected logs under the single key -1, got keys {list(xp_logs)!r}"
            )
            metrics[experiment_id] = xp_logs[-1]

    return metrics