in #U57fa#U7840#U6559#U7a0b/A2-#U795e#U7ecf#U7f51#U7edc#U57fa#U672c#U539f#U7406/#U7b2c7#U6b65 - #U6df1#U5ea6#U795e#U7ecf#U7f51#U7edc/src/ch15-DnnOptimization/Level5_BatchNormTest.py [0:0]
def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running average of
    the mean and variance of each feature, and these averages are used to
    normalize data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7 implementation
    of batch normalization also uses running averages.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability (default 1e-5)
      - momentum: Constant for running mean / variance (default 0.9)
      - running_mean: Array of shape (D,) giving running mean of features;
        zeros of x's dtype if missing or None
      - running_var: Array of shape (D,) giving running variance of features;
        zeros of x's dtype if missing or None

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass. In 'train' mode it
      is (mu, xmu, carre, var, sqrtvar, invvar, va2, va3, gamma, beta, x,
      bn_param); in 'test' mode it is (mu, var, gamma, beta, bn_param).

    Side effects:
    - bn_param['running_mean'] and bn_param['running_var'] are (re)written.
      In 'train' mode they hold the updated averages; in 'test' mode they are
      written back unchanged, or as zeros if they were absent.

    Raises:
    - KeyError if bn_param has no 'mode' key.
    - ValueError if mode is neither 'train' nor 'test'.
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)
    N, D = x.shape

    # Create the zero defaults lazily. Passing np.zeros(...) as the dict.get
    # default would allocate two throw-away arrays on every call, even when the
    # running statistics already exist. An explicit None is treated as "not
    # yet initialised".
    running_mean = bn_param.get('running_mean')
    if running_mean is None:
        running_mean = np.zeros(D, dtype=x.dtype)
    running_var = bn_param.get('running_var')
    if running_var is None:
        running_var = np.zeros(D, dtype=x.dtype)

    out, cache = None, None
    if mode == 'train':
        # Training: normalize with the minibatch statistics, then scale and
        # shift. The computation is split into small steps so that every
        # intermediate needed by the backward pass can be cached.
        # Step 1 - per-feature mean, shape (D,)
        mu = 1 / float(N) * np.sum(x, axis=0)
        # Step 2 - centred data, shape (N, D)
        xmu = x - mu
        # Step 3 - squared deviations, shape (N, D)
        carre = xmu ** 2
        # Step 4 - biased (divide-by-N) variance, shape (D,)
        var = 1 / float(N) * np.sum(carre, axis=0)
        # Step 5 - standard deviation with eps for stability, shape (D,)
        sqrtvar = np.sqrt(var + eps)
        # Step 6 - inverse standard deviation, shape (D,)
        invvar = 1. / sqrtvar
        # Step 7 - normalized data, shape (N, D)
        va2 = xmu * invvar
        # Step 8 - scaled data, shape (N, D)
        va3 = gamma * va2
        # Step 9 - shifted output, shape (N, D)
        out = va3 + beta

        # Exponential moving averages, consumed later by 'test' mode.
        running_mean = momentum * running_mean + (1.0 - momentum) * mu
        running_var = momentum * running_var + (1.0 - momentum) * var

        # Tuple layout must stay fixed: the backward pass unpacks it.
        cache = (mu, xmu, carre, var, sqrtvar, invvar,
                 va2, va3, gamma, beta, x, bn_param)
    elif mode == 'test':
        # Test: normalize with the stored running statistics instead of the
        # batch statistics, then scale and shift.
        mu = running_mean
        var = running_var
        xhat = (x - mu) / np.sqrt(var + eps)
        out = gamma * xhat + beta
        cache = (mu, var, gamma, beta, bn_param)
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var
    return out, cache