def get_hiplot_data()

in scripts/plotting/plot_sweep.py [0:0]


def get_hiplot_data(data, all_y):
    """
    Obtain data for HiPlot plots.

    :param data: The raw experiment data, as obtained by `load_experiments()`. It is
        accessed here as `data[experiment_id][run_id][path] -> logs`, where `run_id` is an
        iterable of "option=value" strings, and `logs` provides `logs["config"]["flags"]`
        (flag name -> value) and `logs["df"]` (a dataframe of logged quantities).
    :param all_y: List of fields from the underlying dataframes, that are the main
        quantities of interest we want to analyze.

    :returns: A tuple `(global_avg_rank, per_experiment_data, all_y_keys)` where:
        - `global_avg_rank` is a list with one HiPlot dict per run that appears in *every*
          experiment, holding the average normalized rank of that run across all
          experiments (each experiment corresponds to a figure displayed by the `plot()`
          function). This makes it possible to identify what works best "on average".
          Runs missing from at least one experiment are excluded, since no average over
          all experiments exists for them.
        - `per_experiment_data` maps each experiment ID to its list of HiPlot dicts (one
          per run). It is used to look more closely at what happens in a specific setting.
        - `all_y_keys` is the set of all HiPlot keys associated to the quantities of
          interest in `all_y`.
    """
    # Map experiment_id -> list of HiPlot dicts for each run in the experiment.
    # This HiPlot data is used to display per-experiment results.
    per_experiment_data = {}
    # Map run_id -> list of HiPlot dicts for each experiment the run appears in.
    # This HiPlot data is used to display global results (across all experiments) based
    # on the average rank of each run.
    run_to_hip = defaultdict(list)
    # All keys associated to fields found in `all_y` (which we want in all HiPlot plots).
    # This set is filled below as we add these keys.
    all_y_keys = set()

    # Gather HiPlot data for each experiment.
    for experiment_id, experiment_data in data.items():
        # The list containing the HiPlot data for each run in the experiment.
        hip_data = []
        # Map y -> list containing the corresponding value at the end of each run.
        all_y_val = defaultdict(list)

        # All options used to identify individual runs in this experiment. Sorted so that
        # the order of HiPlot columns does not depend on string hash randomization.
        hip_args = sorted(
            {option_val.split("=", 1)[0] for run_id in experiment_data for option_val in run_id}
        )

        # List whose i-th element is a HiPlot dict for the i-th run in the experiment.
        # This HiPlot dict can also be found in `run_to_hip`, and will be used to plot
        # rank statistics across all experiments.
        run_index_to_hip = []

        for run_id, run_data in experiment_data.items():
            if not run_data:
                # A run without any job has no values. Skip it so that the i-th entries of
                # `hip_data`, `run_index_to_hip` and `all_y_val[y]` refer to the same run.
                continue
            run_y = defaultdict(list)  # map y -> all values of y for this run (at end of training)
            for idx, (path, logs) in enumerate(run_data.items()):
                flags = logs["config"]["flags"]
                if idx == 0:
                    # This is the first job in the run: use it as reference.
                    # We store in `hip_data` the dict to be used in `per_experiment_data`.
                    hip_data.append({k: flags[LONG_NAMES.get(k, k)] for k in hip_args})
                    # And we make a copy to be used in `run_to_hip`.
                    h = copy.deepcopy(hip_data[-1])
                    # Ensure that for a given run_id, all HiPlot dictionaries share the same
                    # args. Only args known to both dicts are compared: different experiments
                    # may vary different sets of options.
                    for other_h in run_to_hip[run_id]:
                        assert all(other_h[k] == v for k, v in h.items() if k in other_h), (other_h, h)
                    run_to_hip[run_id].append(h)
                    run_index_to_hip.append(h)
                else:
                    # Make sure that other jobs in the same run are consistent in terms of args.
                    assert all(
                        flags[LONG_NAMES.get(k, k)] == hip_data[-1][k] for k in hip_args
                    ), (path, hip_data[-1])

                # NOTE(review): `x` is not defined in this function; presumably a module-level
                # global naming the x-axis column. Rows with a missing x or y are dropped.
                df = logs["df"][[x] + all_y].dropna()
                for y in all_y:
                    # Use a rolling window to extract a stable value for the quantity of interest.
                    run_y[y].append(df[y].rolling(ROLLING_WINDOW, min_periods=0).mean().iloc[-1])

            # The data we use for HiPlot is the mean across all jobs in the run.
            for y in all_y:
                y_mean = np.mean(run_y[y])
                y_key = SHORT_NAMES.get(y, y) + "_last"
                all_y_keys.add(y_key)
                hip_data[-1][y_key] = y_mean
                all_y_val[y].append(y_mean)

        # Compute the rank of each run within this experiment, for each quantity of interest y.
        for y, y_data in all_y_val.items():
            y_key = SHORT_NAMES.get(y, y) + "_rank"
            all_y_keys.add(y_key)
            n = len(y_data)
            # Stable sort: tied values are ranked by run order, deterministically.
            for rank, i in enumerate(np.argsort(y_data, kind="stable")):
                # Add the normalized rank (between 0 and 1) to the HiPlot dictionary
                # associated to the i-th run in this experiment.
                run_index_to_hip[i][y_key] = rank / (n - 1) if n > 1 else 0

        per_experiment_data[experiment_id] = hip_data

    # Compute the average rank across all experiments, for run IDs that appear in all
    # experiments. The first dict of each such run holds the average and is the one returned.
    global_avg_rank = []
    for run_id, hip_dicts in run_to_hip.items():
        if len(hip_dicts) != len(data):
            # This run does not appear in all experiments: it has no global average rank.
            assert len(hip_dicts) < len(data)
            continue
        for y in all_y:
            y_key = SHORT_NAMES.get(y, y) + "_rank"
            hip_dicts[0][y_key] = np.mean([h[y_key] for h in hip_dicts])
        global_avg_rank.append(hip_dicts[0])

    return global_avg_rank, per_experiment_data, all_y_keys