ma_policy/normalizers.py (52 lines of code) (raw):

import tensorflow as tf


def _mean_std_update_size(x, axes):
    x_shape = tf.shape(x)
    x_dims_to_reduce = tf.gather(x_shape, axes)
    size = tf.reduce_prod(x_dims_to_reduce)
    return size


def _interpolate(old, new, old_weight, scaled_weight):
    return old * old_weight + new * scaled_weight


def _std_from_mean_and_square(mean, square):
    var_est = tf.to_float(square) - tf.square(mean)
    return tf.sqrt(tf.maximum(var_est, 1e-2))


class EMAMeanStd(object):
    """
    Calculates an Exponential Moving Average for each argument with
    exponential coefficient `beta`. The forward relation is:

        mean = beta * old_mean + (1.0 - beta) * observation

    The algorithm removes the bias introduced by setting ema[-1] = 0.0.

    Note: `beta` is defined with respect to a single observation within a
    batch if `per_element_update=True` (a batch with 1000 elements of an
    observation counts as 1000 updates); otherwise a full batch counts as a
    single update (`per_element_update=False`).
    """

    def __init__(self, beta, scope="ema", reuse=None, epsilon=1e-6,
                 per_element_update=False, shape=(), version=1):
        self._version = version
        self._per_element_update = per_element_update
        with tf.variable_scope(scope, reuse=reuse):
            # Expected value of x
            self._biased_mean = tf.get_variable(
                dtype=tf.float32, shape=shape,
                initializer=tf.constant_initializer(0.0),
                name="mean", trainable=False)
            # Expected value of x^2
            self._biased_sq = tf.get_variable(
                dtype=tf.float32, shape=shape,
                initializer=tf.constant_initializer(0.0),
                name="sq", trainable=False)
            # How to integrate observations of x over time
            self._one_minus_beta = 1.0 - beta
            # Weight placed on ema[-1] == 0.0, which we divide out to debias
            self._debiasing_term = tf.get_variable(
                dtype=tf.float32, shape=shape,
                initializer=tf.constant_initializer(0.0),
                name="debiasing_term", trainable=False)
            self.shape = shape
            # The stored mean and square are biased due to setting
            # ema[-1] = 0.0; we correct for this by dividing by the
            # debiasing term:
            self.mean = self._biased_mean / tf.maximum(self._debiasing_term, epsilon)
            self.std = _std_from_mean_and_square(
                mean=self.mean,
                square=self._biased_sq / tf.maximum(self._debiasing_term, epsilon))

    def update_op(self, x, axes=(0,)):
        scaled_weight = tf.cast(self._one_minus_beta, tf.float64)
        if self._per_element_update:
            # Many updates were done at once in a batch, so we figure out what
            # power to raise `1 - beta` to, using the fact that for small
            # 1.0 - beta we have: 1 - beta^N ~= (1.0 - beta) * N
            size = _mean_std_update_size(x, axes)
            scaled_weight *= tf.cast(size, tf.float64)
        one = tf.constant(1.0, dtype=tf.float64)
        old_weight = one - scaled_weight
        old_weight_fp32 = tf.to_float(old_weight)
        scaled_weight_fp32 = tf.to_float(scaled_weight)
        return tf.group(
            # Increment the running debiasing term by the contribution of the
            # initial ema[-1] == 0.0 observation (i.e. boost the observed
            # value by how much it was initially discounted on step 1).
            tf.assign(self._debiasing_term, tf.to_float(_interpolate(
                old=tf.cast(self._debiasing_term, tf.float64),
                new=one,
                old_weight=old_weight,
                scaled_weight=scaled_weight))),
            # Interpolate the expected value of X
            tf.assign(self._biased_mean, _interpolate(
                old=self._biased_mean,
                new=tf.reduce_mean(tf.to_float(x), axis=axes),
                old_weight=old_weight_fp32,
                scaled_weight=scaled_weight_fp32)),
            # Interpolate the expected value of X^2
            tf.assign(self._biased_sq, _interpolate(
                old=self._biased_sq,
                new=tf.reduce_mean(tf.square(tf.to_float(x)), axis=axes),
                old_weight=old_weight_fp32,
                scaled_weight=scaled_weight_fp32)),
        )
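To see why dividing by the debiasing term removes the ema[-1] = 0.0 bias, it helps to trace the very first update by hand. The scalar sketch below is plain Python, not part of the file above; the values of `beta` and the observation `x` are made up for illustration. It mirrors `_interpolate` and the `self.mean` property:

beta = 0.99
scaled_weight = 1.0 - beta  # weight given to the new observation
x = 5.0                     # mean of the first observed batch

# First update, starting from the biased initialization of 0.0:
biased_mean = 0.0 * beta + x * scaled_weight       # ~0.05, heavily biased toward 0
debiasing_term = 0.0 * beta + 1.0 * scaled_weight  # ~0.01, total weight accumulated so far

mean = biased_mean / debiasing_term  # recovers 5.0: the ema[-1] = 0.0 bias cancels
assert abs(mean - x) < 1e-9

As more updates arrive, the debiasing term converges to 1.0 and the correction fades out, which matches the docstring's claim that only the initial zero observation is being divided out.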
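And a minimal sketch of how the class might be wired into a TF1 graph to whiten observations. The placeholder name, feature size, `beta` value, and the clip range are assumptions for illustration, not taken from the file:

import numpy as np
import tensorflow as tf

obs = tf.placeholder(tf.float32, shape=(None, 8), name="obs")  # hypothetical observation batch
ema = EMAMeanStd(beta=0.999, scope="obs_norm", shape=(8,), per_element_update=True)
update = ema.update_op(obs, axes=(0,))
# Whiten with an (assumed) clip of +/- 5 standard deviations:
normalized = tf.clip_by_value((obs - ema.mean) / ema.std, -5.0, 5.0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(32, 8).astype(np.float32)
    sess.run(update, feed_dict={obs: batch})
    print(sess.run(normalized, feed_dict={obs: batch}))

Note that with `per_element_update=True`, `update_op` scales `1 - beta` by the batch size reduced over `axes`, so the running statistics move as if the 32 rows had been observed one at a time.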