def mpi_weighted_mean()

in ppo_ewma/logger.py [0:0]


def mpi_weighted_mean(comm, local_name2valcount):
    """
    Compute a count-weighted mean, key by key, over dicts held on different nodes.

    Each caller passes its own dict; the dicts are gathered and rank 0 combines
    them as sum(value * count) / sum(count) for every key it sees.

    Input:
        comm: communicator object providing ``gather(obj)`` and ``rank``
            (presumably an mpi4py communicator -- confirm against callers).
        local_name2valcount: dict mapping key -> (value, count).
            Every value must be accepted by ``float()``.
    Returns:
        On rank 0: dict mapping key -> weighted mean. A key whose counts sum
        to zero has no defined mean and maps to NaN instead of raising
        ZeroDivisionError.
        On all other ranks: an empty dict.
    """
    # Coerce values to built-in floats before gathering; a value that cannot
    # be converted fails here, on the rank that supplied it.
    local_name2valcount = {
        name: (float(val), count)
        for name, (val, count) in local_name2valcount.items()
    }
    all_name2valcount = comm.gather(local_name2valcount)

    # Only rank 0 aggregates the gathered dicts; everyone else is done.
    if comm.rank != 0:
        return {}

    name2sum = defaultdict(float)
    name2count = defaultdict(float)
    for n2vc in all_name2valcount:
        for name, (val, count) in n2vc.items():
            name2sum[name] += val * count
            name2count[name] += count

    # Guard the division: a zero total count yields NaN so the reporting rank
    # does not crash while the others have already returned.
    return {
        name: (name2sum[name] / name2count[name]) if name2count[name] else float("nan")
        for name in name2sum
    }