in research/carls/dynamic_normalization.py [0:0]
def build(self, input_shape):
  """Validates `self.axis` against `input_shape` and creates the parameters.

  Normalizes `self.axis` to a list of non-negative axis indices and checks
  that every axis is in range, unique, and has a statically known size. It
  then records an `InputSpec` and creates the offset/scale parameters through
  `self._add_offset` / `self._add_scale`. `prior_offset` and `prior_scale`
  are created only when `self.use_batch_normalization` is falsy.

  Args:
    input_shape: Shape of the layer input; anything accepted by
      `tf.TensorShape`.

  Raises:
    ValueError: If the input rank is unknown, an axis is out of range or
      duplicated, or the size of a normalized axis is unknown.
  """
  input_shape = tf.TensorShape(input_shape)
  # `rank` is None only when the rank is unknown. A rank-0 shape has a known
  # rank and is rejected by the axis range check below instead.
  if input_shape.rank is None:
    raise ValueError('Input has undefined rank: %s' % (input_shape,))
  ndims = input_shape.rank

  # Build a fresh list: this accepts tuples (which cannot be updated in place)
  # and avoids mutating a list object that may be shared with the caller.
  axes = [self.axis] if isinstance(self.axis, int) else list(self.axis)
  # Resolve negative axes relative to the input rank.
  axes = [x + ndims if x < 0 else x for x in axes]
  for x in axes:
    if x < 0 or x >= ndims:
      raise ValueError('Invalid axis: %d' % x)
  if len(axes) != len(set(axes)):
    raise ValueError('Duplicate axis: %s' % axes)
  # Only commit the normalized axes once they are known to be valid.
  self.axis = axes

  # Static size of each normalized axis; None means unknown at build time.
  axis_to_dim = {
      x: tf.compat.dimension_value(input_shape[x]) for x in axes
  }
  if any(dim is None for dim in axis_to_dim.values()):
    raise ValueError('Input has undefined `axis` dimension. Input shape: %s'
                     % (input_shape,))
  self.input_spec = tf.keras.layers.InputSpec(ndim=ndims, axes=axis_to_dim)

  if len(axes) == 1:
    # Single axis (most common/default use-case): a 1-D parameter shape.
    param_shape = (axis_to_dim[axes[0]],)
  else:
    # Same rank as the input, with 1 in every non-axis dimension.
    param_shape = [axis_to_dim.get(i, 1) for i in range(ndims)]

  self.mean_offset = self._add_offset('mean_offset', param_shape)
  self.mean_scale = self._add_scale('mean_scale', param_shape)
  # Separate prior parameters exist only when batch normalization is off.
  if not self.use_batch_normalization:
    self.prior_offset = self._add_offset('prior_offset', param_shape)
    self.prior_scale = self._add_scale('prior_scale', param_shape)
  self.built = True