in kfac/python/ops/layer_collection.py [0:0]
def _register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
  """Validates `layer_key` and registers `fisher_block` under it.

  Args:
    layer_key: A variable or tuple of variables. The key to check for in
      existing registrations and to register if valid.
    fisher_block: The `FisherBlock` to associate with `layer_key`.
    reuse: Method to use for inserting new `FisherBlock`s. One of True, False,
      `tf.AUTO_REUSE`, or `VARIABLE_SCOPE` (the default). With
      `VARIABLE_SCOPE` the value is taken from the `reuse` attribute of the
      current TensorFlow variable scope.

  Returns:
    The `FisherBlock` registered under `layer_key`. If an existing
    registration is reused (`reuse` is True, or `reuse` is `tf.AUTO_REUSE`
    and `layer_key` is already registered), this is the previously registered
    `FisherBlock` rather than the `fisher_block` argument.

  Raises:
    ValueError: If this `LayerCollection` has already been finalized; if
      `reuse` is True but `layer_key` was not previously registered; if an
      existing registration would be reused but its block type differs from
      `type(fisher_block)`; if `layer_key` is already registered and reuse
      does not apply; or if `layer_key` shares any variable with, but is not
      equal to, a previously registered key.
  """
  if self._finalized:
    raise ValueError("You cannot register additional losses or layers after "
                     "LayerCollection is finalized. Finalization happens "
                     "after the estimator or optimizer object first uses "
                     "the data in the LayerCollection. For example, when "
                     "the minimize() method is called in "
                     "PeriodicInvCovUpdateKfacOpt.")

  # Defer to the surrounding variable scope when no explicit mode was given.
  if reuse is VARIABLE_SCOPE:
    reuse = tf.get_variable_scope().reuse

  # Nothing below mutates self.fisher_blocks before this value is last used,
  # so a single membership test is sufficient.
  already_registered = layer_key in self.fisher_blocks

  if reuse is True or (reuse is tf.AUTO_REUSE and already_registered):
    if not already_registered:
      raise ValueError(
          "reuse was True for attempted registration involving variables {}, "
          "but no previously registered layer was found for these. Perhaps "
          "reuse was set to True by mistake. One way this can happen is if "
          "reuse is set to True in the surrounding variable scope."
          "".format(layer_key))
    existing_block = self.fisher_blocks[layer_key]
    # An exact type match is required on purpose: a subclass of the existing
    # block's type counts as a different block type.
    if type(existing_block) is not type(fisher_block):  # pylint: disable=unidiomatic-typecheck
      raise ValueError(
          "Attempted to register FisherBlock of type %s when existing "
          "FisherBlock has type %s." % (type(fisher_block),
                                        type(existing_block)))
    return existing_block

  # From here on a new block is being inserted, so an existing registration
  # under the same key is always an error; only the message depends on
  # whether reuse was explicitly False.
  if already_registered:
    if reuse is False:
      raise ValueError("FisherBlock for %s is already in LayerCollection." %
                       (layer_key,))
    raise ValueError("Duplicate registration: {}".format(layer_key))

  # Raise an error if any variable in layer_key has been registered in any
  # other block.
  variable_to_registration = {
      var: (registered_key, registered_block)
      for registered_key, registered_block in self.fisher_blocks.items()
      for var in utils.ensure_sequence(registered_key)
  }
  for variable in utils.ensure_sequence(layer_key):
    if variable in variable_to_registration:
      prev_key, prev_block = variable_to_registration[variable]
      raise ValueError(
          "Attempted to register layer_key {} with block {}, but variable {}"
          " was already registered in key {} with block {}.".format(
              layer_key, fisher_block, variable, prev_key, prev_block))

  self.fisher_blocks[layer_key] = fisher_block
  return fisher_block