in kfac/python/ops/tensormatch/graph_search.py [0:0]
def register_subgraph_layers(layer_collection,
                             varlist,
                             subgraph,
                             user_registered_variables=frozenset(),
                             reuse=False,
                             batch_size=None):
  """Walk a subgraph and register all layers to layer_collection.

  The output tensors of every non-identity op in `layer_collection.graph`
  that lie inside `subgraph` are matched against the affine, scale-and-shift,
  batch-norm and fused-batch-norm patterns. Successful matches become
  records, which are filtered by `filter_records` and registered by
  `register_records`. Every variable in `varlist` that ends up neither
  matched nor registered by the user is then registered generically.

  Args:
    layer_collection: A `LayerCollection` to use for registering layers.
    varlist: A list of the variables in the graph.
    subgraph: The `SubGraph` to search.
    user_registered_variables: A set (or any other iterable; `None` is treated
      as empty) of all the variables the user has manually registered. No
      layers using any of these variables should be registered.
    reuse: (OPTIONAL) bool. If True, then `layer_collection`
      selects a previously registered block with the same key as the key
      derived from `params` of that block. If False, a new block is
      registered.
    batch_size: An `int` representing the batch size. Needs to be specified
      if registering generic variables that don't match any layer patterns
      or if the time/uses dimension is folded into batch. If the time/uses
      dimension is merged with batch then this is used to infer the number
      of uses/time-steps.

  Raises:
    ValueError: If any variable specified as part of a linked group was not
      matched with its group; if a variable in `varlist` is neither a
      resource nor a reference variable; if the same variable is used in
      multiple layer types (e.g. fully connected and 2d convolution); or if
      the same variable is used in multiple layers of a type that doesn't
      support shared parameters.
    AmbiguousRegistrationError: If any variables must be registered as generic
      and batch_size is not specified, or if even after filtering, there are
      matches with overlapping but unequal sets of variables (see
      filter_records).
  """
  # Normalize once so both filter_records and the set union below receive the
  # same type as the default argument, whatever iterable the caller passed.
  user_registered_variables = frozenset(user_registered_variables or ())

  # Patterns paired with the function that turns a successful match's
  # bindings into a registration record.
  match_register_list = [
      (gm.matcher_with_consumed(gp.Affine), record_affine_from_bindings),
      (gm.matcher_with_consumed(gp.ScaleAndShift),
       record_scale_and_shift_from_bindings),
      (gm.matcher_with_consumed(gp.BatchNorm), record_batch_norm_from_bindings),
      (gm.matcher_with_consumed(gp.FusedBatchNormOutput),
       record_batch_norm_from_bindings),
  ]

  # Patterns return bindings to raw tensors, so we need to be able to map back
  # to variables from the tensors those variables reference.
  def var_to_tensors(var):
    """Returns the graph tensors through which `var` is referenced."""
    if resource_variable_ops.is_resource_variable(var):
      if tf.control_flow_v2_enabled() and hasattr(layer_collection.graph,
                                                  'captures'):
        # `captures` is assumed to yield (external, internal) pairs; keep the
        # second element of every pair whose first element is this
        # variable's handle.
        # TODO(b/143690035): Note that the "captures" property relies on an
        # API which might change.
        captures = layer_collection.graph.captures
        return [h for vh, h in captures if vh is var.handle]
      return [var.handle]
    if utils.is_reference_variable(var):
      return [tf_ops.internal_convert_to_tensor(var, as_ref=True)]
    raise ValueError('%s is not a recognized variable type.' % str(var))

  tensors_to_variables = {tensor: var for var in varlist
                          for tensor in var_to_tensors(var)}

  # Get all the ops from the graph, dropping tf.identity ops since otherwise
  # the matcher generates spurious matches.
  ops = tuple(op for op in layer_collection.graph.get_operations()
              if not graph_utils.is_identity(op))
  # Keep only the output tensors that lie inside the subgraph.
  tensors = subgraph.filter_list(tuple(out for op in ops for out in op.outputs))

  # Try every pattern against every tensor, grouping the resulting records by
  # the parameters they involve.
  record_list_dict = {}
  for tensor in tensors:
    for match, recfunc in match_register_list:
      match_res = match(tensor)
      if not match_res:
        continue
      bindings, consumed_tensors = match_res
      record = recfunc(bindings, consumed_tensors, tensors_to_variables)
      if record is not None:
        record_list_dict.setdefault(record.params, []).append(record)

  # Filter out records violating any rules.
  record_list_dict = filter_records(layer_collection, record_list_dict,
                                    user_registered_variables)
  # Register the layers by going through the lists of records for each param.
  register_records(layer_collection, record_list_dict, reuse, batch_size)

  # Determine which variables were registered either by the user or
  # in the current call to register_subgraph_layers.
  automatically_registered_variables = {
      var
      for params in record_list_dict
      for var in ensure_sequence(params)
  }
  registered_variables = (
      automatically_registered_variables | user_registered_variables)

  # Register any remaining parameters generically.
  unregistered = [v for v in varlist if v not in registered_variables]
  if not unregistered:
    return

  # Shared tail of the error and warning messages below; loop-invariant.
  generic_bad_string = ('generic registrations may be a symptom that the '
                        'scanner is failing to auto-detect your model. '
                        'Generic uses a last-resort approximation, and '
                        'should never be used for common layer types that '
                        'K-FAC properly supports, such as convs or '
                        'fully-connected layers.')

  for variable in unregistered:
    # A variable the user explicitly linked with others must have been
    # matched together with its group; falling back to generic is an error.
    for specified_grouping in layer_collection.linked_parameters:
      assert isinstance(specified_grouping, frozenset)
      if variable in specified_grouping and len(specified_grouping) > 1:
        raise ValueError(
            'Variable {} in linked group {} was not matched.'.format(
                variable, specified_grouping))
    if batch_size is None:
      raise AmbiguousRegistrationError(
          ('Tried to register {} as generic without knowledge of batch_size. '
           'You can pass batch_size in to fix this error. But please note, '
           + generic_bad_string).format(variable))
    logging.warning(('Registering {} as generic because graph scanner '
                     'couldn\'t match a pattern for it. This can sometimes '
                     'be caused by the variable not being present in the '
                     'graph terminating at the registered losses. You might '
                     'need to pass an explicit list of parameters to tell '
                     'the system what parameters are actually in your model. '
                     'Note that ' + generic_bad_string).format(variable))
    layer_collection.register_generic(variable, batch_size, reuse=reuse)