in pyrela/parse_log.py [0:0]
def average_across_seed(logs):
    """Group per-seed logs by run name and average them element-wise.

    Each key is split on its last underscore. If the suffix starts with
    ``"SEED"`` (e.g. ``"lr0.1_SEED2"``), the entry is grouped under the prefix
    (``"lr0.1"``); a key with no underscore whose text starts with ``"SEED"``
    is grouped under ``"default"``. Any other key is kept as its own group
    (named by the full key) and a notice is printed.

    Within a group, every sequence is truncated to the length of the shortest
    one before averaging, so runs of unequal length are combined over their
    common prefix.

    Args:
        logs: mapping from run key (str) to a sized, sliceable sequence of
            numeric values.

    Returns:
        A ``defaultdict(list)`` mapping each group name to ``(mean, sem)``:
        ``mean`` is a Python list of per-step means across the group's runs,
        and ``sem`` is a numpy array of per-step standard errors computed as
        the population standard deviation (``ddof=0``) divided by
        ``sqrt(number of runs)``.

    Raises:
        AssertionError: if any sequence in a group is empty.
    """
    new_logs = defaultdict(list)
    for k, v in logs.items():
        s = k.rsplit("_", 1)
        if len(s) == 2:
            name, seed = s
        else:
            # rsplit with maxsplit=1 yields one part only when there is no "_".
            name, seed = "default", s[0]
        if not seed.startswith("SEED"):
            # Report the full key: ``name`` here is only the truncated prefix
            # (or "default") and would not identify which run was skipped.
            print("no multiple seeds, omit averaging: ", k)
            name = k
        new_logs[name].append(v)

    # Iterate over a snapshot since each group's value is replaced below.
    for k, vals in list(new_logs.items()):
        min_len = np.min([len(v) for v in vals])
        assert min_len > 0, min_len
        # Truncate to the shortest run so all runs stack into one array.
        stacked = np.stack([v[:min_len] for v in vals])
        mean = stacked.mean(0).tolist()
        sem = stacked.std(0) / np.sqrt(stacked.shape[0])
        new_logs[k] = (mean, sem)
    return new_logs