in scripts/plotting/plot_sweep.py [0:0]
def display_hiplot(name, hiplot_data, max_cols_per_plot=0, mandatory_cols=None, sort_cols=False):
    """
    Utility function to display HiPlot data.

    The main job of this function is to create multiple plots when there are too
    many columns, so as to keep each plot readable.

    :param name: Title printed before each plot, together with the plot index.
    :param hiplot_data: Iterable of dicts (one per data point) mapping column name
        to value. Different dicts may have different sets of keys.
    :param max_cols_per_plot: Maximum number of columns in each plot. Ignored if <= 0.
    :param mandatory_cols: List/set of columns that should be visible in each plot
        (when present in the data).
    :param sort_cols: Whether columns should be sorted (alphabetically). Otherwise
        columns keep the order in which they are first encountered in the data.
    :raises ValueError: If the columns must be split across several plots but
        `max_cols_per_plot` leaves no room for any non-mandatory column once the
        mandatory columns present in the data are included.
    """
    # A set gives O(1) membership tests regardless of what container was passed.
    mandatory_cols = set() if mandatory_cols is None else set(mandatory_cols)
    all_data = list(hiplot_data)
    # `dict.fromkeys` de-duplicates while preserving first-seen order, so column
    # order (and therefore how columns are grouped into plots) is deterministic
    # across runs, unlike iterating over a `set` of strings.
    all_keys = list(dict.fromkeys(k for hip_dict in all_data for k in hip_dict))
    if sort_cols:
        all_keys.sort()

    if max_cols_per_plot <= 0 or len(all_keys) <= max_cols_per_plot:
        # Easy case: we can fit everything in a single plot.
        print(f"*** {name} (1/1) ***")
        hiplot.Experiment.from_iterable(all_data).display()
        return

    # Must split keys across multiple plots: every plot shows the mandatory
    # columns plus its own slice of the remaining ones.
    must_have = [k for k in all_keys if k in mandatory_cols]
    other_keys = [k for k in all_keys if k not in mandatory_cols]
    n_other_keys_per_plot = max_cols_per_plot - len(must_have)
    if n_other_keys_per_plot <= 0:
        raise ValueError(
            f"max_cols_per_plot={max_cols_per_plot} must be greater than the number "
            f"of mandatory columns present in the data ({len(must_have)})"
        )

    starts = range(0, len(other_keys), n_other_keys_per_plot)
    n_plots = len(starts)  # total number of plots
    for plot_idx, start in enumerate(starts, start=1):
        plotted_keys = other_keys[start:start + n_other_keys_per_plot] + must_have
        # Rows may lack some columns (keys are the union over all rows), so only
        # copy the keys each row actually has.
        to_display = [{k: h[k] for k in plotted_keys if k in h} for h in all_data]
        print(f"*** {name} ({plot_idx} / {n_plots}) ***")
        hiplot.Experiment.from_iterable(to_display).display()