in kfac/python/ops/layer_collection.py
def register_scale_and_shift(self,
params,
inputs,
outputs,
approx=None,
reuse=VARIABLE_SCOPE):
"""Registers a scale and shift operation.
A scale and shift operation is a parameterized operation of the form
outputs = scale * inputs + shift ,
where scale and shift are variables that broadcast to the shape of inputs.
outputs and inputs must have batch dimension. scale and shift can have
a corresponding dimension (although they don't need to), but it must
be 1.
These kinds of operations appear frequently in various "normalization"
layers like Layer Normalization. Batch Normalization layers should still
be registered as "generic".
Note that this is an experimental feature that hasn't been experimentally
validated or published on.
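
    Example (an illustrative sketch; 'depth', 'normalized_inputs', and 'lc'
    are assumed names, not part of this API):

      scale = tf.get_variable("scale", shape=[depth])
      shift = tf.get_variable("shift", shape=[depth])
      outputs = scale * normalized_inputs + shift
      lc.register_scale_and_shift(params=(scale, shift),
                                  inputs=normalized_inputs,
                                  outputs=outputs)
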
Args:
params: Variable or 2-tuple of Variables corresponding to the scale
and possibly shift parameters (scale must be first). Note that if
these have a dimension corresponding to the batch dimension of 'inputs'
and 'outputs', that dimension must be 1.
      inputs: Tensor of shape [batch_size, ...]. Input tensor that is
        multiplied by the scale parameter.
outputs: Tensor of shape [batch_size, ...]. Final output produced by the
scale and shift. Must have the same shape as 'inputs'.
      approx: str or None. If not None, must be one of "full" or "diagonal".
        The Fisher approximation to use. If None, the default value is used.
        (Default: None)
reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
(Default: "VARIABLE_SCOPE")
Raises:
      ValueError: For an improper value of 'approx'.
      KeyError: If reuse == True but no FisherBlock is found for 'params'.
      ValueError: If reuse == True and a FisherBlock is found but it is of the
        wrong type.
"""
# TODO(jamesmartens): Consider replacing some of the logic below with calls
# to tf.broadcast_static_shape.
if isinstance(params, (tuple, list)):
scale = params[0]
shift = params[1]
has_shift = True
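      # Align shift's shape with the trailing dimensions of outputs:
      # start_dim is the index of the first outputs dimension that shift's
      # shape lines up with (0 means shift has a batch dimension).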
start_dim = len(outputs.shape) - len(shift.shape)
if start_dim < 0:
raise ValueError("Rank of shift cannot exceed that of outputs.")
if start_dim == 0 and shift.shape[0] != 1:
raise ValueError("If shift has a batch dimension its value must be 1.")
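      # Dimensions of outputs that shift's rank doesn't reach (other than the
      # batch dimension) are broadcast over; the loop below also records any
      # size-1 dimensions of shift that broadcast against larger output dims.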
broadcast_dims_shift = list(range(1, start_dim))
for i in range(max(start_dim, 1), len(outputs.shape)):
if shift.shape[i - start_dim] < outputs.shape[i]:
if shift.shape[i - start_dim] == 1:
broadcast_dims_shift.append(i)
else:
raise ValueError("It appears that shift param and output have "
"incompatible shapes. This is probably due to "
"misspecified arguments.")
elif shift.shape[i - start_dim] > outputs.shape[i]:
raise ValueError("It appears that shift param and output have "
"incompatible shapes. This is probably due to "
"misspecified arguments.")
broadcast_dims_shift = tuple(broadcast_dims_shift)
else:
has_shift = False
scale = params
broadcast_dims_shift = None
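    # Compute the dimensions of inputs over which scale is broadcast, using
    # the same alignment logic as for shift above.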
start_dim = len(inputs.shape) - len(scale.shape)
if start_dim < 0:
raise ValueError("Rank of scale cannot exceed that of inputs.")
if start_dim == 0 and scale.shape[0] != 1:
raise ValueError("If scale has a batch dimension its value must be 1.")
broadcast_dims_scale = list(range(1, start_dim))
    for i in range(max(start_dim, 1), len(inputs.shape)):
      if scale.shape[i - start_dim] < inputs.shape[i]:
        if scale.shape[i - start_dim] == 1:
          broadcast_dims_scale.append(i)
        else:
          raise ValueError("It appears that scale param and input have "
                           "incompatible shapes. This is probably due to "
                           "misspecified arguments.")
      elif scale.shape[i - start_dim] > inputs.shape[i]:
        # Mirrors the corresponding over-sized-dimension check for shift.
        raise ValueError("It appears that scale param and input have "
                         "incompatible shapes. This is probably due to "
                         "misspecified arguments.")
broadcast_dims_scale = tuple(broadcast_dims_scale)
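    # Choose the FisherBlock class implementing the requested (or default)
    # Fisher approximation for scale-and-shift layers.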
block_type, approx = self._get_block_type(
params, approx, self.default_scale_and_shift_approximation,
self._scale_and_shift_approx_to_block_types)
    block = self._register_block(
        params,
        block_type(self,
                   broadcast_dims_scale,
                   broadcast_dims_shift=broadcast_dims_shift,
                   has_shift=has_shift),
        reuse=reuse)
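    # Record this mini-batch/tower of data with the block and count the use
    # of 'params'.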
block.register_additional_tower(inputs, outputs)
self._add_uses(params, 1)
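
    # Worked example (illustrative): for inputs/outputs of shape [B, T, D]
    # and scale/shift each of shape [D], start_dim = 2, so
    # broadcast_dims_scale == broadcast_dims_shift == (1,): the T dimension
    # is broadcast over, the D dimension matches exactly, and the batch
    # dimension is excluded.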