def average_across_seed()

in pyrela/parse_log.py [0:0]


def average_across_seed(logs):
    """Average logged curves over runs that differ only by their seed suffix.

    Each key is split on its last underscore. If the trailing piece starts
    with ``"SEED"``, the run is grouped under the prefix before that
    underscore (or under ``"default"`` when the key contains no underscore).
    Any other key is treated as a run without multiple seeds: a notice is
    printed and the run forms its own group under its full key.

    Within a group, every curve is truncated to the length of the shortest
    one before the element-wise mean and standard error are computed. The
    standard error uses the population standard deviation (NumPy's default
    ``ddof=0``) divided by ``sqrt(number of runs in the group)``.

    Args:
        logs: Mapping from run name (str) to a sequence of numeric values.
            Values must support ``len()`` and slicing, and the truncated
            slices must be stackable by ``np.stack``.

    Returns:
        A ``defaultdict(list)`` mapping group name to ``(mean, sem)``, where
        ``mean`` is a Python list and ``sem`` is a NumPy array.

    Raises:
        AssertionError: If the shortest curve in a group is empty.
    """
    grouped = defaultdict(list)
    for key, curve in logs.items():
        # rsplit with maxsplit=1 yields either [prefix, suffix] or [key].
        parts = key.rsplit("_", 1)
        if len(parts) == 2:
            name, seed = parts
        else:
            name, seed = "default", parts[0]
        if not seed.startswith("SEED"):
            # Report the full key: at this point `name` only holds the
            # truncated prefix (or "default"), which would be misleading.
            print("no multiple seeds, omit averaging: ", key)
            name = key
        grouped[name].append(curve)

    # Only values of existing keys are replaced below, so the dict size does
    # not change while it is being iterated.
    for name, curves in grouped.items():
        min_len = min(len(curve) for curve in curves)
        assert min_len > 0, min_len
        # Truncate to the shortest run so the curves stack into a 2-D array.
        stacked = np.stack([curve[:min_len] for curve in curves])
        mean = stacked.mean(0).tolist()
        sem = stacked.std(0) / np.sqrt(stacked.shape[0])
        grouped[name] = (mean, sem)

    return grouped