def augment()

in scripts/plotting/plot_sweep.py [0:0]


def augment(data, results, x="step", y="mean_episode_return"):
    """Add flat "published results" reference lines to ``data`` in place.

    First, for every environment in ``data``, the smallest and largest value
    of column ``x`` across all runs of all models is determined. Then each
    entry ``results[env][model]`` is inserted as ``data[env][model]``. It is a
    dict with a single pseudo-run ``"published_results"`` whose ``"df"`` is a
    two-row DataFrame: ``x`` takes the environment's min and max, and ``y``
    repeats the published value. This draws a horizontal line spanning the
    sweep's x-range.

    Models with no runs, and runs whose DataFrame is empty, contribute no
    x-values and are ignored when computing the range.

    Args:
        data: Nested mapping ``env -> model -> run -> config``, where each
            ``config["df"]`` is a DataFrame containing column ``x``.
            Modified in place.
        results: Mapping ``env -> model -> value``. Each value is placed in
            both rows of column ``y`` (presumably a scalar — confirm with
            callers).
        x: Name of the x-axis column read from the run DataFrames.
        y: Name of the y-axis column in the generated DataFrames.

    Raises:
        KeyError: If ``results`` has a model for an environment that has no
            usable x-values in ``data`` (or is missing from ``data``).
    """
    # Per-environment {"xmin", "xmax"} over every non-empty run of every model.
    boundaries = {}
    for env, models in data.items():
        for runs in models.values():
            columns = [config["df"][x].values for config in runs.values()]
            # Skip empty DataFrames (np.amin/np.amax raise on zero-size
            # arrays) and models with no runs (min()/max() of an empty
            # sequence raise ValueError).
            columns = [column for column in columns if column.size]
            if not columns:
                continue
            xmin = min(np.amin(column) for column in columns)
            xmax = max(np.amax(column) for column in columns)
            bounds = boundaries.setdefault(env, {"xmin": xmin, "xmax": xmax})
            # min/max keep the existing value on ties, like strict </> checks.
            bounds["xmin"] = min(bounds["xmin"], xmin)
            bounds["xmax"] = max(bounds["xmax"], xmax)

    for env, models in results.items():
        for model, result in models.items():
            bounds = boundaries[env]
            # NOTE(review): this replaces any existing data[env][model] entry
            # wholesale, so a sweep model sharing a name with a published
            # result would lose its runs — confirm that is intended.
            data[env][model] = {
                "published_results": {
                    "df": pd.DataFrame(
                        {
                            x: [bounds["xmin"], bounds["xmax"]],
                            y: [result, result],
                        }
                    )
                }
            }