lib/normalize_ewma.py

import numpy as np
import torch
import torch.nn as nn


class NormalizeEwma(nn.Module):
    """Normalize a vector of observations - across the first norm_axes dimensions"""

    def __init__(self, input_shape, norm_axes=2, beta=0.99999, per_element_update=False, epsilon=1e-5):
        super().__init__()

        self.input_shape = input_shape
        self.norm_axes = norm_axes
        self.epsilon = epsilon
        self.beta = beta
        self.per_element_update = per_element_update

        # Running statistics are stored as non-trainable parameters so they end up in the
        # module's state_dict and move with .to(device), but receive no gradients.
        self.running_mean = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False)
        self.running_mean_sq = nn.Parameter(torch.zeros(input_shape, dtype=torch.float), requires_grad=False)
        self.debiasing_term = nn.Parameter(torch.tensor(0.0, dtype=torch.float), requires_grad=False)

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_mean_sq.zero_()
        self.debiasing_term.zero_()

    def running_mean_var(self):
        # Dividing by the debiasing term corrects for the zero initialization of the
        # running averages (the same bias-correction idea used in Adam).
        debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)
        debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)
        debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)
        return debiased_mean, debiased_var

    def forward(self, input_vector):
        # Make sure input is float32
        input_vector = input_vector.to(torch.float)

        if self.training:
            # Detach input before adding it to running means to avoid backpropping through it on
            # subsequent batches.
            detached_input = input_vector.detach()
            batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes)))
            batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes)))

            if self.per_element_update:
                # Apply the decay once per element that was averaged over, not once per batch,
                # so larger batches move the running statistics further.
                batch_size = np.prod(detached_input.size()[: self.norm_axes])
                weight = self.beta ** batch_size
            else:
                weight = self.beta

            self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
            self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
            self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))

        mean, var = self.running_mean_var()

        # Indexing with (None,) * norm_axes adds leading singleton dimensions so the
        # per-feature statistics broadcast against the leading axes of the input.
        return (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]

    def denormalize(self, input_vector):
        """Transform normalized data back into original distribution"""
        mean, var = self.running_mean_var()
        return input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]
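

# --- Minimal usage sketch (not part of the original module). The shapes below are
# illustrative assumptions: with the default norm_axes=2, inputs are expected to have
# two leading axes (e.g. time and batch) followed by input_shape, and the statistics
# are averaged over those two leading axes.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Hypothetical example: normalize scalar targets of shape (time=10, batch=8, 1).
    normalizer = NormalizeEwma(input_shape=(1,), norm_axes=2)
    targets = torch.randn(10, 8, 1) * 5.0 + 3.0

    normalizer.train()                 # training mode updates the running statistics
    normalized = normalizer(targets)

    normalizer.eval()                  # eval mode only applies the frozen statistics
    recovered = normalizer.denormalize(normalized)

    print("normalized mean/std:", normalized.mean().item(), normalized.std().item())
    print("round trip matches input:", torch.allclose(recovered, targets, atol=1e-4))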