def build()

in research/carls/dynamic_normalization.py [0:0]


  def build(self, input_shape):
    """Validates `input_shape` and creates the layer's normalization weights.

    Resolves `self.axis` into a list of non-negative axis indices, records an
    `InputSpec` fixing the input rank and the sizes of the normalized axes, and
    creates `mean_offset`/`mean_scale` (plus `prior_offset`/`prior_scale` when
    `self.use_batch_normalization` is false) via `self._add_offset` and
    `self._add_scale`.

    Args:
      input_shape: Shape of the layer input; anything accepted by
        `tf.TensorShape`.

    Raises:
      ValueError: If the input rank is unknown, an axis is out of range or
        repeated, or a normalized axis has an undefined size.
    """
    input_shape = tf.TensorShape(input_shape)
    # Only an unknown rank is rejected here; rank 0 is a defined rank and is
    # rejected below by the axis range check instead of being misreported.
    if input_shape.ndims is None:
      raise ValueError('Input has undefined rank:', input_shape)
    ndims = len(input_shape)

    # Normalize `axis` (an int or a sequence of ints) into a fresh list of
    # non-negative indices. Building a new list, rather than assigning into
    # `self.axis`, also works for tuples and never mutates a sequence the
    # caller may still hold.
    axes = [self.axis] if isinstance(self.axis, int) else list(self.axis)
    self.axis = [x + ndims if x < 0 else x for x in axes]

    # Validate axes (this also catches negative axes below -ndims).
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % self.axis)

    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    if any(dim is None for dim in axis_to_dim.values()):
      raise ValueError('Input has undefined `axis` dimension. Input shape: ',
                       input_shape)
    self.input_spec = tf.keras.layers.InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1:
      # Single axis (most common/default use-case): a 1-D parameter shape.
      param_shape = (axis_to_dim[self.axis[0]],)
    else:
      # Keep the input's rank, with size 1 in every non-normalized dim
      # (presumably so the parameters broadcast against the input).
      param_shape = [axis_to_dim.get(i, 1) for i in range(ndims)]

    self.mean_offset = self._add_offset('mean_offset', param_shape)
    self.mean_scale = self._add_scale('mean_scale', param_shape)

    # Prior parameters are only created when batch normalization is disabled.
    if not self.use_batch_normalization:
      self.prior_offset = self._add_offset('prior_offset', param_shape)
      self.prior_scale = self._add_scale('prior_scale', param_shape)

    self.built = True