in kfac/python/keras/utils.py [0:0]
def get_layer_collection(model,
                         loss=None,
                         loss_weights=None,
                         fisher_approx=None,
                         layer_collection=None,
                         seed=None):
  """Builds a LayerCollection with the model's layers and losses registered.

  Args:
    model: Keras model whose layers to register. Currently, Conv1D, Conv2D,
      Dense, BatchNormalization, LayerNormalization and Embedding layers are
      supported in a Functional or Sequential model. Other layer types are
      supported as long as they aren't trainable (or don't have weights).
      Nested models are supported.
    loss: Optional Keras (normal or serialized) loss function, or, when there
      are multiple losses, a list or a dictionary mapping output layer names
      to (normal or serialized) loss functions. Currently, sparse/normal
      categorical/binary cross entropy and MSE are supported. At least one
      loss must be registered with the layer collection before it can be
      used.
    loss_weights: Optional list of coefficients, or dictionary mapping output
      layer names to the coefficient for each loss function. A list must
      contain the same number of coefficients as there are loss functions.
      With a dictionary, losses without an entry get a coefficient of 1.0.
    fisher_approx: Optional list of approximations, or dictionary mapping
      layer name/class to fisher approximation type. A list must contain the
      same number of approximations as there are layers with trainable
      parameters, ordered to match the layers given by model.layers. With a
      dictionary, the layer name is checked first, then the layer class, and
      otherwise the default approximation is used. None is a valid list/dict
      entry and selects the default approximation for that layer.
    layer_collection: Optional LayerCollection object on which the model and
      loss will be registered; a new one is created when omitted.
    seed: Optional integer specifying the TensorFlow random seed. Setting it
      is required for deterministic behaviour, because targets are sampled to
      approximate the fisher.

  Raises:
    ValueError: If there is a layer with trainable parameters that isn't
      Conv1D, Conv2D, Dense, BatchNormalization, LayerNormalization or
      Embedding.
    ValueError: If a loss function other than MSE and cross entropy variants
      is used.
    ValueError: If there isn't a one-to-one correspondence between
      loss/loss_weights and output layers, or if loss_weights isn't a
      list/dict.
    ValueError: If convolutional layers don't use the "channels_last" format.

  Returns:
    A kfac.LayerCollection with the model's layers and loss registered.
  """
  if not layer_collection:
    layer_collection = kfac_layer_collection.LayerCollection()

  # Normalize `loss` into a dict keyed by output layer name.
  if not loss:
    loss = {}
  elif isinstance(loss, dict):
    if set(loss.keys()) != set(model.output_names):
      raise ValueError("Output layer names and loss dict keys don't match"
                       " \nmodel.output_names: {} \nloss dict keys: {}"
                       .format(model.output_names, loss.keys()))
  elif isinstance(loss, list):
    if len(loss) != len(model.output_names):
      raise ValueError("Number of loss dict items doesn't match number of "
                       "output layers. \nmodel.output_names: {} \nloss list: "
                       "{}".format(model.output_names, loss))
    loss = dict(zip(model.output_names, loss))
  else:
    if len(model.output_names) > 1:
      raise ValueError('More output layers than losses. \n'
                       'model.output_names: {} \nloss: {}'
                       .format(model.output_names, loss))
    # A model used as a callable may report output_names that differ from the
    # actual output layer's name. With a single output the last layer is
    # always the one we want, so key the loss by its name.
    loss = {model.layers[-1].name: loss}

  # Flatten nested models with a left-to-right depth-first traversal. The
  # resulting order only matters when fisher_approx is given in list form.
  flattened_layers = []
  pending = list(reversed(model.layers))
  while pending:
    current = pending.pop()
    if not hasattr(current, 'layers'):
      flattened_layers.append(current)
      continue
    # `current` is a nested model: re-key any loss registered under the
    # nested model's name to its single output layer, then schedule its
    # sublayers for traversal.
    if current.name in loss:
      if len(current.output_names) > 1:
        raise ValueError('Nested models with multiple outputs are '
                         'unsupported.')
      loss[current.output_names[0]] = loss.pop(current.name)
    pending.extend(reversed(current.layers))

  trainable_layer_names = [
      l.name for l in flattened_layers if l.count_params() and l.trainable
  ]
  fisher_approx = _get_verified_dict(fisher_approx, 'fisher_approx',
                                     trainable_layer_names)
  # The Optimizer class passes in a serialized fisher_approx dictionary, but
  # the user may not; serialize here so it can be consumed uniformly.
  fisher_approx = serialize_fisher_approx(fisher_approx)
  loss_weights = _get_verified_dict(loss_weights, 'loss_weights',
                                    model.output_names)

  for layer in flattened_layers:
    # A layer-name entry takes precedence over a class-name entry; a missing
    # entry (or an explicit None) selects the default approximation.
    approx = fisher_approx.get(
        layer.name,
        fisher_approx.get(_CLASS_NAME_PREFIX + layer.__class__.__name__))
    register_layer(layer_collection, layer, fisher_approx=approx)
    if layer.name in loss:
      register_loss(layer_collection=layer_collection,
                    layer=layer,
                    loss=loss[layer.name],
                    coeff=loss_weights.get(layer.name, 1.0),
                    seed=seed)
  return layer_collection