in pyhanabi/tools/parse_log.py [0:0]
def average_across_seed(logs, *, ddof=0):
    """Group per-seed curves by run name and average them step by step.

    Each key of ``logs`` is split on its last underscore. If the part after
    that underscore starts with ``"SEED"`` (e.g. ``"exp_lr1_SEED3"``), the
    part before it is used as the group name, so all seeds of one run are
    merged. Otherwise the full key becomes its own group. A key without any
    underscore is grouped under ``"default"`` only when the key itself
    starts with ``"SEED"``; otherwise it, too, keeps its full key.

    Curves in one group may have different lengths. At step ``i`` only the
    curves longer than ``i`` contribute, so later steps may be averaged over
    fewer seeds.

    Args:
        logs: mapping from run key to a sequence of numbers (one per step).
        ddof: delta degrees of freedom forwarded to ``np.std`` when computing
            the standard error. The default 0 keeps the historical
            (population-std) behaviour; pass 1 for the conventional
            sample-std estimate of the standard error of the mean.

    Returns:
        A ``defaultdict`` mapping group name to a ``(means, sems)`` tuple of
        two equal-length lists. Their length is that of the longest curve in
        the group. ``sems`` is 0 at steps where only one curve contributes.
    """
    groups = defaultdict(list)
    for key, curve in logs.items():
        prefix, sep, suffix = key.rpartition("_")
        # rpartition yields ("", "", key) when there is no underscore; that
        # case is mapped to the "default" group name.
        name = prefix if sep else "default"
        if not suffix.startswith("SEED"):
            # Not a per-seed run: keep the full key as its own group.
            name = key
        groups[name].append(curve)

    # Only values of existing keys are replaced below, so iterating the dict
    # while assigning to it is safe (no keys are added or removed).
    for name, curves in groups.items():
        means = []
        sems = []
        max_len = max(len(c) for c in curves)
        for step in range(max_len):
            nums = [c[step] for c in curves if len(c) > step]
            means.append(np.mean(nums))
            if len(nums) == 1:
                # A single sample has no spread; with ddof=1 np.std would be NaN.
                sems.append(0)
            else:
                sems.append(np.std(nums, ddof=ddof) / np.sqrt(len(nums)))
        groups[name] = (means, sems)
    return groups