# ppo_ewma/reward_normalizer.py
import torch as th
import torch.distributed as dist

from . import torch_util as tu  # the repo's torch utility helpers (assumed import path)


def update(self, x):
    # Per-batch moments over the batch dimension
    batch_mean = x.mean(dim=0)
    batch_var = x.var(dim=0, unbiased=False)
    batch_count = th.tensor([x.shape[0]], device=x.device, dtype=th.float32)
    if self.distributed:
        # flatten + unflatten so we only need a single all-reduce
        flat = tu.flatten_tensors([batch_mean, batch_var, batch_count])
        flat = flat.to(device=tu.dev())  # otherwise all_mean_ will fail
        tu.all_mean_(flat)
        tu.unflatten_to(flat, [batch_mean, batch_var, batch_count])
        # all_mean_ averages across ranks, so rescale the count back to a global total
        batch_count *= dist.get_world_size()
    self.update_from_moments(batch_mean, batch_var, batch_count[0])
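

# For reference, a minimal sketch of the `update_from_moments` merge called above,
# using the standard parallel-variance (Chan et al.) formula. The attribute names
# `self.mean`, `self.var`, and `self.count` are assumptions for illustration, not
# taken from the source; the actual class may track its running statistics differently.
def update_from_moments(self, batch_mean, batch_var, batch_count):
    delta = batch_mean - self.mean
    tot_count = self.count + batch_count
    # Shift the running mean toward the batch mean, weighted by batch size
    new_mean = self.mean + delta * batch_count / tot_count
    # Combine second moments, with a correction term for the shift between means
    m_a = self.var * self.count
    m_b = batch_var * batch_count
    m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / tot_count
    self.mean = new_mean
    self.var = m_2 / tot_count
    self.count = tot_count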