def sync_params()

in ppo_ewma/torch_util.py


def sync_params(params, src_rank=0, group=dist.group.WORLD, comm=None, use_mpi=False):
    """
    Send parameters from src_rank to all others in the group
    """
    datas = [p.data for p in params]
    # Pack all parameter tensors into a single flat vector so that only one
    # broadcast is needed.
    flatvec = flatten_tensors(datas)
    if use_mpi:
        if comm is None:
            comm = DEFAULT_COMM
        # MPI path: convert to a numpy array, broadcast, then convert back.
        # Note that this path broadcasts from rank 0 regardless of src_rank.
        flatvec = th2np(flatvec)
        comm.Bcast(flatvec, root=0)
        flatvec = np2th(flatvec)
    else:
        # torch.distributed path: broadcast the flat tensor from src_rank.
        dist_broadcast(flatvec, src=src_rank, group=group)
    # Copy the broadcast values back into the original parameter tensors.
    unflatten_to(flatvec, datas)
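
For illustration, a typical call site synchronizes a freshly constructed model so every worker starts from rank 0's weights. The snippet below is a minimal sketch, not part of the repo: it assumes sync_params is importable from ppo_ewma.torch_util, uses a placeholder model, and initializes torch.distributed with the gloo backend as if launched with torchrun.

import torch.distributed as dist
import torch.nn as nn

from ppo_ewma.torch_util import sync_params  # assumed import path

def make_model():
    # Placeholder network; each rank builds its own randomly initialized copy.
    return nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 2))

def main():
    # Initialize the default process group; rank/world-size env vars would
    # normally be supplied by the launcher (e.g. torchrun).
    dist.init_process_group(backend="gloo")
    model = make_model()
    # Overwrite every rank's parameters with rank 0's values so all workers
    # begin training from identical weights.
    sync_params(model.parameters(), src_rank=0)
    dist.destroy_process_group()

if __name__ == "__main__":
    main()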