def average_across_seed()

in pyhanabi/tools/parse_log.py [0:0]


def average_across_seed(logs, ddof=0):
    """Group per-seed logs by run name and average them element-wise.

    Each key of ``logs`` is split at its last underscore:

    * ``"exp_SEED1"`` -> grouped under ``"exp"`` (suffix starts with "SEED");
    * ``"SEED1"`` (no underscore, but starts with "SEED") -> ``"default"``;
    * any other key (e.g. ``"exp"`` or ``"exp_lr3"``) -> its own full name.

    Series within a group may have different lengths. At each index the
    mean and standard error of the mean (SEM) are computed over only those
    series long enough to have an entry at that index.

    Args:
        logs: mapping from run key to a sequence of numbers.
        ddof: delta degrees of freedom passed to ``np.std`` for the SEM.
            The default ``0`` preserves the historical behaviour (population
            std); ``1`` gives the conventional sample-based SEM.

    Returns:
        A ``defaultdict(list)`` mapping each group name to a ``(means, sems)``
        tuple of two lists whose length equals the longest series in that
        group. The SEM is ``0`` at indices where only one series contributes.
    """
    groups = defaultdict(list)
    for key, series in logs.items():
        prefix, sep, suffix = key.rpartition("_")
        if not suffix.startswith("SEED"):
            # Not a per-seed run: keep it as its own group.
            name = key
        elif sep:
            name = prefix
        else:
            # The whole key is a seed tag with no run name in front of it.
            name = "default"
        groups[name].append(series)

    result = defaultdict(list)
    for name, runs in groups.items():
        max_len = max(len(run) for run in runs)
        means = []
        sems = []
        for i in range(max_len):
            # Ragged series: only runs that reached step i contribute.
            nums = [run[i] for run in runs if len(run) > i]
            means.append(np.mean(nums))
            if len(nums) == 1:
                # A single sample has no spread (and ddof=1 would yield NaN).
                sems.append(0)
            else:
                sems.append(np.std(nums, ddof=ddof) / np.sqrt(len(nums)))
        result[name] = (means, sems)

    return result