def register_scale_and_shift()

in kfac/python/ops/layer_collection.py [0:0]


  def register_scale_and_shift(self,
                               params,
                               inputs,
                               outputs,
                               approx=None,
                               reuse=VARIABLE_SCOPE):
    """Registers a scale and shift operation.

    A scale and shift operation is a parameterized operation of the form

    outputs = scale * inputs + shift ,

    where scale and shift are variables that broadcast to the shape of inputs.

    outputs and inputs must have batch dimension. scale and shift can have
    a corresponding dimension (although they don't need to), but it must
    be 1.

    These kinds of operations appear frequently in various "normalization"
    layers like Layer Normalization. Batch Normalization layers should still
    be registered as "generic".

    Note that this is an experimental feature that hasn't been experimentally
    validated or published on.

    Args:
      params: Variable or 2-tuple of Variables corresponding to the scale
        and possibly shift parameters (scale must be first).  Note that if
        these have a dimension corresponding to the batch dimension of 'inputs'
        and 'outputs', that dimension must be 1.
      inputs: Tensor of shape [batch_size, ...]. Input tensor that is multiplied
        by the scale tensor.
      outputs: Tensor of shape [batch_size, ...]. Final output produced by the
        scale and shift. Must have the same shape as 'inputs'.
      approx: str or None. If not None must be one of "full" or "diagonal".
        The Fisher approximation to use. If None the default value is used.
        (Default: None)
      reuse: bool or str.  If True, this adds 'inputs' and 'outputs' as an
        additional mini-batch/tower of data to use when estimating the Fisher
        block for this layer (which must have already been registered). If
        "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
        (Default: "VARIABLE_SCOPE")

    Raises:
      ValueError: For improper value to 'approx'.
      ValueError: If scale (resp. shift) has a higher rank than 'inputs'
        (resp. 'outputs'), has a batch dimension whose size is not 1, or has a
        non-batch dimension that neither equals the corresponding dimension of
        'inputs' (resp. 'outputs') nor is 1 (i.e. broadcastable to it).
      KeyError: If reuse == True but no FisherBlock found for 'params'.
      ValueError: If reuse == True and FisherBlock found but of the wrong type.
    """

    def _get_broadcast_dims(param, tensor, param_name, tensor_name):
      """Returns the dims of `tensor` along which `param` gets broadcast.

      The dims of `param` are aligned with the trailing dims of `tensor`.
      Leading dims of `tensor` that `param` lacks, plus dims where `param` has
      size 1 and `tensor` is larger, are reported. Dim 0 of `tensor` (the
      batch dimension) is never included.

      Args:
        param: The scale or shift variable.
        tensor: The tensor `param` is broadcast against.
        param_name: "scale" or "shift"; used in error messages.
        tensor_name: Singular name of `tensor` ("input" or "output"); used in
          error messages.

      Returns:
        A tuple of ints.

      Raises:
        ValueError: If the shapes of `param` and `tensor` are incompatible.
      """
      # TODO(jamesmartens): Consider replacing some of this logic with calls
      # to tf.broadcast_static_shape.
      start_dim = len(tensor.shape) - len(param.shape)
      if start_dim < 0:
        raise ValueError("Rank of {} cannot exceed that of {}s.".format(
            param_name, tensor_name))
      # Equal ranks means `param` carries a batch dimension, which must be 1.
      if start_dim == 0 and param.shape[0] != 1:
        raise ValueError("If {} has a batch dimension its value must be "
                         "1.".format(param_name))
      incompatible_msg = ("It appears that {} param and {} have incompatible "
                          "shapes. This is probably due to misspecified "
                          "arguments.".format(param_name, tensor_name))

      # Leading (non-batch) dims of `tensor` that `param` does not have at all.
      broadcast_dims = list(range(1, start_dim))
      for i in range(max(start_dim, 1), len(tensor.shape)):
        param_dim = param.shape[i - start_dim]
        tensor_dim = tensor.shape[i]
        if param_dim < tensor_dim:
          if param_dim == 1:
            broadcast_dims.append(i)
          else:
            raise ValueError(incompatible_msg)
        elif param_dim > tensor_dim:
          # A param dim larger than the tensor dim can never broadcast. The
          # previous scale-only code path silently accepted this case.
          raise ValueError(incompatible_msg)
      return tuple(broadcast_dims)

    if isinstance(params, (tuple, list)):
      scale = params[0]
      shift = params[1]
      has_shift = True
      # shift is added after scaling, so it is broadcast against 'outputs'.
      broadcast_dims_shift = _get_broadcast_dims(
          shift, outputs, "shift", "output")
    else:
      has_shift = False
      scale = params
      broadcast_dims_shift = None

    broadcast_dims_scale = _get_broadcast_dims(scale, inputs, "scale", "input")

    block_type, approx = self._get_block_type(
        params, approx, self.default_scale_and_shift_approximation,
        self._scale_and_shift_approx_to_block_types)

    block = self._register_block(params, block_type(
        self,
        broadcast_dims_scale,
        broadcast_dims_shift=broadcast_dims_shift,
        has_shift=has_shift),
                                 reuse=reuse)
    block.register_additional_tower(inputs, outputs)

    self._add_uses(params, 1)