def register_subgraph_layers()

in kfac/python/ops/tensormatch/graph_search.py [0:0]


def register_subgraph_layers(layer_collection,
                             varlist,
                             subgraph,
                             user_registered_variables=frozenset(),
                             reuse=False,
                             batch_size=None):
  """Scan a subgraph for known layer patterns and register what is found.

  Every output tensor of every non-identity op in `layer_collection.graph`
  that also belongs to `subgraph` is tested against the supported patterns
  (affine, scale-and-shift, batch norm, fused batch norm). Matches are turned
  into records, filtered via `filter_records`, and registered via
  `register_records`. Any variable in `varlist` that is still unregistered
  afterwards is registered generically.

  Args:
    layer_collection: A `LayerCollection` to use for registering layers.
    varlist: A list of the variables in the graph.
    subgraph: The `SubGraph` to search.
    user_registered_variables: A set of all the variables the user has manually
      registered. No layers using any of these variables should be registered.
    reuse: (OPTIONAL) bool. If True, then `layer_collection` selects a
      previously registered block with the same key as the key derived from
      `params` of that block. If False, a new block is registered.
    batch_size: An `int` representing the batch size. Needs to be specified if
      registering generic variables that don't match any layer patterns or
      if the time/uses dimension is folded into batch. If the time/uses
      dimension is merged with batch then this is used to infer the number of
      uses/time-steps.

  Raises:
    ValueError: If a variable in `varlist` is neither a resource variable nor
      a reference variable.
      If any variables specified as part of linked groups were not matched
      with their group.
      If the same variable is used in multiple layer types (e.g. fully
      connected and 2d convolution), or if the same variable is used in
      multiple layers of a type that doesn't support shared parameters.
    AmbiguousRegistrationError: If any variables must be registered as generic
      and batch_size is not specified, or if even after filtering, there are
      matches with overlapping but unequal sets of variables (see
      filter_records).
  """

  # Each graph pattern is paired with the function that converts the bindings
  # of a successful match into a registration record.
  pattern_recorders = (
      (gp.Affine, record_affine_from_bindings),
      (gp.ScaleAndShift, record_scale_and_shift_from_bindings),
      (gp.BatchNorm, record_batch_norm_from_bindings),
      (gp.FusedBatchNormOutput, record_batch_norm_from_bindings),
  )
  matchers_and_recorders = [
      (gm.matcher_with_consumed(pattern), recorder)
      for pattern, recorder in pattern_recorders
  ]

  # Matches bind raw tensors rather than variables, so build a reverse lookup
  # from each tensor that stands for a variable back to that variable.
  def _tensors_for(variable):
    if resource_variable_ops.is_resource_variable(variable):
      if (not tf.control_flow_v2_enabled() or
          not hasattr(layer_collection.graph, 'captures')):
        return [variable.handle]
      # TODO(b/143690035): Note that the "captures" property relies on an
      # API which might change.
      return [
          captured
          for outer_handle, captured in layer_collection.graph.captures
          if outer_handle is variable.handle
      ]
    if utils.is_reference_variable(variable):
      return [tf_ops.internal_convert_to_tensor(variable, as_ref=True)]
    raise ValueError('%s is not a recognized variable type.' % str(variable))

  tensors_to_variables = {}
  for graph_variable in varlist:
    for tensor in _tensors_for(graph_variable):
      tensors_to_variables[tensor] = graph_variable

  # Collect the outputs of every op in the graph, skipping tf.identity ops
  # (they would otherwise make the matcher produce spurious matches), and keep
  # only the tensors that live in the requested subgraph.
  candidate_tensors = subgraph.filter_list(
      tuple(output
            for op in layer_collection.graph.get_operations()
            if not graph_utils.is_identity(op)
            for output in op.outputs))

  # Try every pattern on every candidate tensor, bucketing the resulting
  # records by the params they cover.
  record_list_dict = {}
  for tensor in candidate_tensors:
    for matcher, recorder in matchers_and_recorders:
      match_result = matcher(tensor)
      if not match_result:
        continue
      bindings, consumed_tensors = match_result
      record = recorder(bindings, consumed_tensors, tensors_to_variables)
      if record is None:
        continue
      record_list_dict.setdefault(record.params, []).append(record)

  # Drop records that violate any of the registration rules.
  record_list_dict = filter_records(layer_collection, record_list_dict,
                                    user_registered_variables)

  # Perform the actual registrations for the surviving records.
  register_records(layer_collection, record_list_dict, reuse, batch_size)

  # Work out which variables are now accounted for, either by the records
  # registered above or by the user's own manual registrations.
  automatically_registered_variables = set()
  for params in record_list_dict:
    automatically_registered_variables.update(ensure_sequence(params))
  registered_variables = (
      automatically_registered_variables | user_registered_variables)

  # The message templates do not depend on the variable, so build them once.
  generic_bad_string = ('generic registrations may be a symptom that the '
                        'scanner is failing to auto-detect your model. '
                        'Generic uses a last-resort approximation, and '
                        'should never be used for common layer types that '
                        'K-FAC properly supports, such as convs or '
                        'fully-connected layers.')
  missing_batch_size_template = (
      'Tried to register {} as generic without knowledge of batch_size. '
      'You can pass batch_size in to fix this error. But please note, '
      + generic_bad_string)
  generic_warning_template = (
      'Registering {} as generic because graph scanner '
      'couldn\'t match a pattern for it. This can sometimes '
      'be caused by the variable not being present in the '
      'graph terminating at the registered losses. You might '
      'need to pass an explicit list of parameters to tell '
      'the system what parameters are actually in your model. '
      'Note that ' + generic_bad_string)

  # Whatever is left over gets a generic registration.
  for variable in varlist:
    if variable in registered_variables:
      continue

    # A variable belonging to a multi-member linked group must have been
    # matched together with that group; reaching this point means it was not.
    for specified_grouping in layer_collection.linked_parameters:
      assert isinstance(specified_grouping, frozenset)
      if variable in specified_grouping and len(specified_grouping) > 1:
        raise ValueError(
            'Variable {} in linked group {} was not matched.'.format(
                variable, specified_grouping))

    if batch_size is None:
      raise AmbiguousRegistrationError(
          missing_batch_size_template.format(variable))
    logging.warning(generic_warning_template.format(variable))
    layer_collection.register_generic(variable, batch_size, reuse=reuse)