def get_layer_collection()

in kfac/python/keras/utils.py


def get_layer_collection(model,
                         loss=None,
                         loss_weights=None,
                         fisher_approx=None,
                         layer_collection=None,
                         seed=None):
  """Get layer collection with all layers and loss registered.

  Args:
   model: Keras model whose layers will be registered. Currently, Conv1D,
     Conv2D, Dense, BatchNormalization, LayerNormalization and Embedding
     layers are supported in a Functional or Sequential model. Other layer
     types are supported as long as they aren't trainable (or don't have
     weights). Nested models are supported.
   loss: Optional Keras loss function (normal or serialized). If there are
     multiple losses, it may be a list, or a dictionary mapping output layer
     names to (normal or serialized) loss functions. Currently, sparse/normal
     categorical/binary cross entropy and MSE are supported. You must register
     at least one loss with the layer collection before it can be used.
   loss_weights: An optional list of coefficients or a dictionary mapping
     layer names to the coefficient for each loss function. If it is a list,
     there must be the same number of coefficients as loss functions. If
     it is a dictionary and a coefficient is not given for a loss function,
     a coefficient of 1.0 will be used.
   fisher_approx: An optional list of approximations or a dictionary mapping
     layer name/class to Fisher approximation type. If it is a list, there
     must be the same number of approximations as there are layers with
     trainable parameters. For each layer, the approximation is determined as
     follows: if fisher_approx is a dictionary, the layer's name is checked
     first; if the name isn't found, the layer's class is checked; if that
     isn't found either, the default approximation is used. When fisher_approx
     is a list, the order of the approximations must match the order of the
     layers with trainable parameters given by model.layers. None is a valid
     dict/list entry and indicates that the default approximation should be
     used for that layer.
   layer_collection: Optional LayerCollection object on which the model and loss
     will be registered.
   seed: Optional integer specifying the TensorFlow random seed. To get
     deterministic behaviour, the seed must be set, because the targets
     are sampled to approximate the Fisher.

  Raises:
   ValueError: If there is a layer with trainable parameters that isn't Conv1D,
     Conv2D, Dense, BatchNormalization, LayerNormalization or Embedding.
   ValueError: If a loss function other than MSE and cross entropy
     variants is used.
   ValueError: If there isn't a one-to-one correspondence between
     loss/loss_weights and output layers, or if loss_weights isn't a list/dict.
   ValueError: If convolutional layers don't use the "channels_last" format.

  Returns:
    A kfac.LayerCollection with the model's layers and loss registered.
  """
  if not layer_collection:
    layer_collection = kfac_layer_collection.LayerCollection()

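  # Normalize `loss` into a dict mapping output layer names to loss functions.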
  if not loss:
    loss = {}
  elif isinstance(loss, dict):
    if set(model.output_names) != set(loss.keys()):
      raise ValueError('Output layer names and loss dict keys don\'t match'
                       ' \nmodel.output_names: {} \nloss dict keys: {}'
                       .format(model.output_names, loss.keys()))
  elif isinstance(loss, list):
    if len(model.output_names) != len(loss):
      raise ValueError('Number of losses in the list doesn\'t match number of '
                       'output layers. \nmodel.output_names: {} \nloss list: '
                       '{}'.format(model.output_names, loss))
    loss = dict(zip(model.output_names, loss))
  else:
    if len(model.output_names) > 1:
      raise ValueError('More output layers than losses. \n'
                       'model.output_names: {} \nloss: {}'
                       .format(model.output_names, loss))
    # When the model is used as a callable, the model's output_names may not
    # match the actual output layer's name. In the one output case, we always
    # want the last layer, so we use the last layer's name.
    loss = {model.layers[-1].name: loss}

  # We want to do a left-to-right depth-first traversal of the model to get the
  # correct flattened order of the layers. The order only matters for the
  # fisher_approx in list form.
  flattened_layers = []
  layer_stack = model.layers[::-1]
  while layer_stack:
    layer = layer_stack.pop()
    if hasattr(layer, 'layers'):
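      # A loss keyed by a nested model's name is re-keyed to the name of the
      # nested model's output layer, so it is found when that layer is
      # registered below.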
      if layer.name in loss:
        if len(layer.output_names) > 1:
          raise ValueError('Nested models with multiple outputs are '
                           'unsupported.')
        loss[layer.output_names[0]] = loss.pop(layer.name)
      layer_stack += layer.layers[::-1]
    else:
      flattened_layers.append(layer)

  trainable_layer_names = [l.name for l in flattened_layers if
                           l.count_params() and l.trainable]
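  # Normalize/validate fisher_approx against the trainable layer names (in
  # list form it must match their number and order, per the docstring).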
  fisher_approx = _get_verified_dict(fisher_approx, 'fisher_approx',
                                     trainable_layer_names)
  # The Optimizer class passes in a serialized fisher_approx dictionary, but the
  # user may not. We serialize it so we can use it uniformly.
  fisher_approx = serialize_fisher_approx(fisher_approx)
  loss_weights = _get_verified_dict(loss_weights, 'loss_weights',
                                    model.output_names)

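  # Register every layer, resolving its Fisher approximation by layer name
  # first, then by layer class, then falling back to the default. Layers that
  # produce a model output also get their loss registered.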
  for layer in flattened_layers:
    if layer.name in fisher_approx:
      approx = fisher_approx[layer.name]
    else:
      approx = fisher_approx.get(
          _CLASS_NAME_PREFIX + layer.__class__.__name__, None)

    register_layer(layer_collection, layer, fisher_approx=approx)

    if layer.name in loss:
      register_loss(layer_collection=layer_collection,
                    layer=layer,
                    loss=loss[layer.name],
                    coeff=loss_weights.get(layer.name, 1.0),
                    seed=seed)

  return layer_collection
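
A minimal usage sketch (not part of the library source), assuming a simple
Sequential model: it registers the model and a single loss on a fresh
LayerCollection. The import path, layer names, and approximation strings
below are illustrative assumptions, not confirmed library values.

import tensorflow as tf
from kfac.python.keras import utils  # assumed import path for this module

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, name='logits'),
])

# One output layer, so a single loss is allowed. fisher_approx mixes a
# layer-name key with a layer-class key (the name takes precedence, per
# the docstring); the approximation strings here are placeholders.
layer_collection = utils.get_layer_collection(
    model,
    loss='sparse_categorical_crossentropy',
    fisher_approx={'logits': 'kron',        # hypothetical approximation names
                   'Dense': 'diagonal'},
    seed=0)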