def _register_block()

in kfac/python/ops/layer_collection.py [0:0]


  def _register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
    """Validates and registers the layer_key associated with the fisher_block.

    Args:
      layer_key: A variable or tuple of variables. The key to check for in
          existing registrations and to register if valid.
      fisher_block: The associated `FisherBlock`.
      reuse: Method to use for inserting new `FisherBlock`s. One of True,
        False, tf.AUTO_REUSE, or VARIABLE_SCOPE (the default). VARIABLE_SCOPE
        means the value of `tf.get_variable_scope().reuse` is used instead.
        With True, `layer_key` must already be registered and the existing
        block is returned. With tf.AUTO_REUSE, the existing block is returned
        if `layer_key` is registered; otherwise `fisher_block` is registered.
        With any other value, `layer_key` must not already be registered.

    Raises:
      ValueError: If the LayerCollection has already been finalized; if
        `reuse` is True but `layer_key` was not previously registered; if
        `layer_key` was already registered and `reuse` is neither True nor
        tf.AUTO_REUSE; if `layer_key` was registered with a different block
        type (when reusing); or if `layer_key` shares any variables with but
        is not equal to a previously registered key.

    Returns:
      The `FisherBlock` registered under `layer_key`. If `layer_key` was already
      registered, this will be the previously registered `FisherBlock`.
    """
    if self._finalized:
      raise ValueError("You cannot register additional losses or layers after "
                       "LayerCollection is finalized. Finalization happens "
                       "after the estimator or optimizer object first uses "
                       "the data in the LayerCollection. For example, when "
                       "the minimize() method is called in "
                       "PeriodicInvCovUpdateKfacOpt.")

    # VARIABLE_SCOPE is a sentinel meaning "defer to the enclosing variable
    # scope". Identity is checked first; an equal string is also accepted,
    # because string identity depends on interning and is not guaranteed for
    # a caller-supplied literal. (Assumes VARIABLE_SCOPE is the string
    # "VARIABLE_SCOPE", as the original docstring implied — confirm at its
    # definition.)
    if reuse is VARIABLE_SCOPE or (isinstance(reuse, str) and
                                   reuse == VARIABLE_SCOPE):
      reuse = tf.get_variable_scope().reuse

    # fisher_blocks is not mutated until the final insertion below, so a
    # single membership test serves every branch.
    already_registered = layer_key in self.fisher_blocks

    if reuse is True or (reuse is tf.AUTO_REUSE and already_registered):
      if not already_registered:
        raise ValueError(
            "reuse was True for attempted registration involving variables {}, "
            "but no previously registered layer was found for these. Perhaps "
            "reuse was set to True by mistake. One way this can happen is if "
            "reuse is set to True in the surrounding variable scope."
            "".format(layer_key))

      result = self.fisher_blocks[layer_key]

      # Exact type match (not isinstance): the reused block must be of the
      # very same class as the block the caller is trying to register.
      if type(result) != type(fisher_block):  # pylint: disable=unidiomatic-typecheck
        raise ValueError(
            "Attempted to register FisherBlock of type %s when existing "
            "FisherBlock has type %s." % (type(fisher_block), type(result)))
      return result

    if already_registered:
      if reuse is False:
        raise ValueError("FisherBlock for %s is already in LayerCollection." %
                         (layer_key,))
      # Any other non-reusing value of `reuse` (e.g. None) ends up here.
      raise ValueError("Duplicate registration: {}".format(layer_key))

    # A variable may belong to at most one registered key: raise if any
    # variable in layer_key already appears in another registration.
    variable_to_block = {
        var: (params, block)
        for (params, block) in self.fisher_blocks.items()
        for var in utils.ensure_sequence(params)
    }
    for variable in utils.ensure_sequence(layer_key):
      if variable in variable_to_block:
        prev_key, prev_block = variable_to_block[variable]
        raise ValueError(
            "Attempted to register layer_key {} with block {}, but variable {}"
            " was already registered in key {} with block {}.".format(
                layer_key, fisher_block, variable, prev_key, prev_block))

    self.fisher_blocks[layer_key] = fisher_block
    return fisher_block