in utils/evaluate.py [0:0]
def _const_row_from_path(file):
    """Parse the experiment name and hyperparameters encoded in ``file``.

    The first ``/``-separated component is the experiment directory. Its
    leading ``len("experiment") + 1`` characters are dropped, which assumes
    names like ``experiment_<name>`` (TODO confirm). Every later component
    is split at its LAST underscore into ``<key>_<value>``, and the value is
    passed through ``str_to_val``. A component without an underscore yields
    the key ``""``.
    """
    experiment_dir, *hyperparam_parts = file.split("/")
    const_row = {"experiment": experiment_dir[len("experiment") + 1:]}
    for part in hyperparam_parts:
        # rpartition splits at the last "_", so keys may themselves contain "_".
        key, _, value = part.rpartition("_")
        const_row[key] = str_to_val(value)
    return const_row


def _check_same_num_epochs(history_to_plot):
    """Raise ``ValueError`` unless every series in ``history_to_plot`` has equal length.

    The message names the first key and the first key whose length differs
    from it. An empty mapping passes trivially.
    """
    keys = list(history_to_plot)
    if not keys:
        return
    first_key = keys[0]
    n_epochs = len(history_to_plot[first_key])
    for key in keys[1:]:
        if len(history_to_plot[key]) != n_epochs:
            raise ValueError(
                f"Number of epochs not the same for (at least) {first_key} and {key}."
            )


def get_rows_cols_agg(files, raw_files):
    """Build one row per (run, epoch) for line plots.

    Args:
        files: ``/``-separated relative paths, one per run. The first component
            names the experiment. Each later component encodes a hyperparameter
            as ``<key>_<value>``.
        raw_files: Paths passed to ``History.from_file``, paired positionally
            with ``files`` via ``zip``. Extra entries in the longer sequence
            are ignored.

    Returns:
        A tuple ``(rows, None, to_plot_cols)``:
        ``rows`` is a list of dicts, each holding the run's constant columns
        (``"experiment"`` plus the hyperparameters), an ``"epochs"`` index,
        and that epoch's value for every plotted series.
        The second element is always ``None``.
        ``to_plot_cols`` lists every plotted series name seen, in first-seen
        order.

    Raises:
        ValueError: If the plotted series of a single run have differing lengths.
    """
    # A dict is used as an insertion-ordered set, so the returned column order
    # is deterministic. A set of str iterates in hash order, which varies
    # between interpreter runs.
    to_plot_cols = {}
    rows = []
    for raw_file, file in zip(raw_files, files):
        const_row = _const_row_from_path(file)
        history = History.from_file(raw_file)

        # Keep only the series that is_plot accepts. Candidate keys come from
        # the first history entry. Each column is fetched once and reused for
        # both the filter and the stored value.
        history_to_plot = {}
        for key in history[0].keys():
            column = history[:, key]
            if is_plot(key, column):
                history_to_plot[key] = column

        _check_same_num_epochs(history_to_plot)

        # Presumably transposes {key: per-epoch values} into one
        # {key: value} mapping per epoch. Verify against
        # cont_tuple_to_tuple_cont.
        for epoch, history_per_epoch in enumerate(
            cont_tuple_to_tuple_cont(history_to_plot)
        ):
            row = const_row.copy()
            row["epochs"] = epoch
            for key, value in history_per_epoch.items():
                row[key] = value
                to_plot_cols[key] = None
            rows.append(row)
    return rows, None, list(to_plot_cols)