in kfac/python/keras/utils.py [0:0]
def register_layer(layer_collection, layer, fisher_approx=None, **kwargs):
"""Get layer collection with all layers and loss registered.
Args:
layer_collection: LayerCollection object on which the layer will be
registered.
layer: Keras layer to register with the layer_collection.
fisher_approx: Optional string specifying the Fisher approximation type to
use for this layer. If None, the LayerCollection's default approximation is
used.
**kwargs: Keyword arguments to be forwarded to the layer registration
function.
Raises:
ValueError: If the layer has trainable parameters and is not a Conv1D,
Conv2D, Dense, BatchNormalization, LayerNormalization, or Embedding layer.
ValueError: If a convolutional layer does not use the "channels_last" data
format.
ValueError: If a Batch/LayerNormalization layer was created with scale=False.
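Example: a minimal sketch of registering every layer of a built Keras model.
It assumes `model` is a built model whose outputs are logits and that
`kfac.LayerCollection` is the core KFAC layer collection; the final call is
the usual loss registration for a softmax cross-entropy loss.

  layer_collection = kfac.LayerCollection()
  for keras_layer in model.layers:
    register_layer(layer_collection, keras_layer)
  layer_collection.register_categorical_predictive_distribution(model.output)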
"""
# The inbound_nodes property is currently deprecated, but appears to be
# supported in non-eager TF 1.x. This may change.
# If there are multiple inbound_nodes, it means the model was used as a
# callable (i.e. y = model(x)). We assume the inputs/outputs from the call
# need to be registered and not the nodes from the original built model or
# any other previous calls, since layers can't be used multiple times
# (RNN-style) with Keras KFAC.
node = layer.inbound_nodes[-1]
pre_activation_output = node.output_tensors
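# KFAC must be registered with the pre-activation output of the layer, so if
# the layer applies a non-linear activation, walk back from the layer's output
# tensor to the input of the activation op.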
if hasattr(layer, 'activation') and layer.activation != activations.linear:
pre_activation_output = get_parent(pre_activation_output)
# This will allow unsupported layers to be in our model as long as KFAC
# doesn't have to minimize with respect to those parameters.
if layer.count_params() and layer.trainable:
if any(isinstance(tensor, (list, tuple))
for tensor in (node.input_tensors, node.output_tensors)):
raise ValueError('Individual layers can only have one input tensor and one '
'output tensor. You are likely using an unsupported '
'layer type. Error on layer {}'.format(layer))
weights = layer.trainable_weights
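# The KFAC registration functions take either a single parameter tensor or a
# (weights, bias) pair, so a lone trainable weight is passed as a bare
# variable rather than a one-element list.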
kwargs.update({
'inputs': node.input_tensors,
'outputs': pre_activation_output,
'params': weights if len(weights) > 1 else weights[0],
'approx': fisher_approx,
})
# TODO(b/133849249) Support RNNs and other shared weight layers.
if isinstance(layer, layers.Dense):
layer_collection.register_fully_connected(**kwargs)
elif isinstance(layer, layers.Embedding):
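# An embedding lookup is registered as a fully connected layer whose inputs
# are ids rather than dense activations, hence dense_inputs=False.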
layer_collection.register_fully_connected(dense_inputs=False, **kwargs)
elif isinstance(layer, (layers.BatchNormalization,
layers.LayerNormalization)):
if not layer.scale:
# With Batch/Layer Normalization, the user can specify if they want
# the input to be scaled and/or shifted after it is normalized.
raise ValueError('KFAC currently does not support batch/layer '
'normalization with scale=False. Error on layer {}'
.format(layer))
# Undo the normalization by subtracting the shift (beta) and dividing by the
# scale (gamma), so the registered inputs are the normalized activations fed
# into the scale-and-shift.
kwargs['inputs'] = ((kwargs['outputs'] - weights[1]) / weights[0]
if layer.center else kwargs['outputs'] / weights[0])
layer_collection.register_scale_and_shift(**kwargs)
# A learning_phase of 1 or 0 means it's been set. False means it hasn't.
is_phase_set = K.get_value(K.learning_phase()) != False # pylint: disable=g-explicit-bool-comparison
if hasattr(layer, 'fused') and layer.fused and not is_phase_set:
# For the fused implementation of the BatchNormalization, there are
# two ops: one for training and one for inference. When the
# learning_phase is set, during layer creation, there is a
# tf_utils.smart_cond that will only create one of the ops. When the
# learning_phase is not set, it will create a tf.cond with both ops as
# branches. So, when learning_phase is not set, we must add a "use"
# for the gamma/beta variables to account for there being two ops that
# are consumers of the variables. Linked below is the smart_cond in
# BatchNormalization:
# https://github.com/tensorflow/tensorflow/blob/59217f581fdef4e5469a98b62e38f851eac88688/tensorflow/python/keras/layers/normalization.py#L513
# Updated 2019-06-22.
layer_collection._add_uses(weights, 1) # pylint: disable=protected-access
elif all(hasattr(layer, a) for a in
('strides', 'padding', 'dilation_rate')):
if layer.data_format != 'channels_last':
raise ValueError('KFAC currently only supports the "channels_last" '
'data format for convolutional layers. Error on '
'layer {}'.format(layer))
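# Convert the Keras layer config into the layout the KFAC conv registration
# expects: uppercase padding and strides/dilations with explicit batch and
# channel entries, e.g. strides=(2, 2) -> [1, 2, 2, 1].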
kwargs['padding'] = layer.padding.upper()
kwargs['strides'] = [1] + list(layer.strides) + [1]
kwargs['dilations'] = [1] + list(layer.dilation_rate) + [1]
if isinstance(layer, layers.Conv2D):
layer_collection.register_conv2d(**kwargs)
elif isinstance(layer, layers.Conv1D):
layer_collection.register_conv1d(**kwargs)
# Depthwise and Separable Conv2D are not supported yet because they are
# experimental in tensorflow_kfac.
else:
raise ValueError('Unsupported convolutional layer type: {}'
.format(layer))
# TODO(b/133849240): Support registering any convolution type.
else:
raise ValueError('Unsupported layer type: {}'.format(layer))