def update()

in ppo_ewma/reward_normalizer.py


    def update(self, x):
        # Moments of the current batch (population variance, i.e. unbiased=False).
        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        batch_count = th.tensor([x.shape[0]], device=x.device, dtype=th.float32)
        if self.distributed:
            # flatten+unflatten so we just need one allreduce
            flat = tu.flatten_tensors([batch_mean, batch_var, batch_count])
            flat = flat.to(device=tu.dev())  # Otherwise all_mean_ will fail
            tu.all_mean_(flat)
            tu.unflatten_to(flat, [batch_mean, batch_var, batch_count])
            # all_mean_ averages across workers, so scale the count back up to
            # the total number of samples contributed this step.
            batch_count *= dist.get_world_size()
        self.update_from_moments(batch_mean, batch_var, batch_count[0])
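
The method ends by handing the (mean, var, count) triple to `update_from_moments`, which is not shown here. Below is a minimal sketch of what such a method typically does, assuming the standard parallel-variance combination used by running mean/std trackers; the field names `self.mean`, `self.var`, and `self.count` are illustrative assumptions, not necessarily the ones used in `reward_normalizer.py`.

    import torch as th

    class RunningMeanStdSketch:
        """Illustrative running mean/variance tracker (sketch only)."""

        def __init__(self, shape=(), epsilon=1e-4):
            self.mean = th.zeros(shape)
            self.var = th.ones(shape)
            self.count = epsilon  # small prior count avoids division by zero

        def update_from_moments(self, batch_mean, batch_var, batch_count):
            # Fold the batch moments into the running moments using the
            # parallel-variance combination formula.
            delta = batch_mean - self.mean
            tot_count = self.count + batch_count

            new_mean = self.mean + delta * batch_count / tot_count
            m_a = self.var * self.count
            m_b = batch_var * batch_count
            m2 = m_a + m_b + delta.pow(2) * self.count * batch_count / tot_count

            self.mean = new_mean
            self.var = m2 / tot_count
            self.count = tot_count

With something like this in place, `update(x)` above reduces a batch (averaged across workers when distributed) to a single (mean, var, count) triple and folds it into the running statistics used for reward normalization.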