def register_records()

in kfac/python/ops/tensormatch/graph_search.py [0:0]


def register_records(layer_collection,
                     record_list_dict,
                     reuse=False,
                     batch_size=None):
  """Registers the given records to layer_collection.

  Params are processed in a deterministic order (sorted by their `str`
  representation) and each set of params is dispatched on its record type:

    * fully_connected / conv2d: several records for the same params are
      registered as one "multi" layer. A single record is also registered as a
      "multi" layer (with `num_uses` set) when `batch_size` is given and the
      leading dimension of its inputs differs from `batch_size`, i.e. when
      uses/time-steps appear to be folded into the batch dimension.
    * scale_and_shift: only a single record per set of params is supported.
    * batch_norm: registered via `register_generic`; requires `batch_size`.

  Params whose records have mixed record types are skipped; the corresponding
  errors are collected and raised together after all other params have been
  registered.

  Args:
    layer_collection: A `LayerCollection` to use for registering layers.
    record_list_dict: A dict mapping tuples of variables to lists of
      `MatchRecord`s representing all of the places those variables are used
      in the graph.
    reuse: (OPTIONAL) bool. If True, then `layer_collection`
      selects a previously registered block with the same key as the key
      derived from `params` of that block. If False, a new block is
      registered.
    batch_size: A `int` representing the batch size. Needs to specified if
      registering generic variables that don't match any layer patterns or
      if time/uses is folded. If the time/uses dimension is merged with
      batch then this is used to infer number of uses/time-steps.

  Raises:
    ValueError: If record_list_dict contains multiple record types for a single
      set of variables, or if there are multiple records for a set of variables
      of a type that doesn't support shared parameters, or if `batch_size` is
      given and the leading dimension of a record's inputs is smaller than, or
      not a multiple of, `batch_size`, or if a record has an unrecognized
      record type.
    AmbiguousRegistrationError: If a batch norm layer registration is required
      but batch_size is not passed.
  """

  mixed_record_type_errors = []

  # TODO(b/69627702): Layers must be registered in a deterministic order, else
  # FisherFactors may end up with different variable names.
  params_list = sorted(record_list_dict.keys(), key=str)
  for params in params_list:
    record_list = record_list_dict[params]
    record_type = record_list[0].record_type
    # We don't support mixed types for the same params and probably never
    # will. Collect the error and keep going so every offending set of params
    # is reported at once.
    if not all(record_type == record.record_type for record in record_list):
      mixed_record_type_errors.append(
          'Detected variables {} with mixed record types: {}.'.format(
              params, record_list))
      continue

    # `num_uses` stays None unless batch_size is known and the uses/time-steps
    # dimension appears to be folded into the leading (batch) dimension of the
    # inputs. Only the first record's inputs are inspected.
    num_uses = None
    if batch_size:
      first_dim = record_list[0].data['inputs'].shape.as_list()[0]
      if first_dim is not None and first_dim != batch_size:
        num_uses = first_dim // batch_size
        if num_uses == 0:
          raise ValueError('It looks like the batch_size passed to the auto-'
                           'registration function was larger than expected. '
                           'The likely cause of this is that you passed in '
                           'the overall batch size instead of the per-replica '
                           'batch size. When using K-FAC with replication all '
                           'batch sizes passed to K-FAC and its helper modules '
                           'should be their per-replica sizes.')
        # Integer division would silently truncate here and register a
        # `num_uses` that does not match the actual input shape.
        if first_dim % batch_size:
          raise ValueError(
              'The leading dimension ({}) of the inputs of the layer using '
              'variables {} is not a multiple of batch_size ({}), so the '
              'number of uses/time-steps folded into the batch dimension '
              'cannot be inferred.'.format(first_dim, params, batch_size))

    if record_type is RecordType.fully_connected:
      if len(record_list) > 1:
        logging.info('Registering as multi fully-connected: %s', params)

        inputs = tuple(record.data['inputs'] for record in record_list)
        outputs = tuple(record.data['outputs'] for record in record_list)
        layer_collection.register_fully_connected_multi(
            params, inputs, outputs, reuse=reuse)
      else:
        logging.info('Registering as fully-connected: %s', params)

        record = record_list[0]
        inputs = record.data['inputs']
        outputs = record.data['outputs']
        if num_uses is not None:
          layer_collection.register_fully_connected_multi(
              params, inputs, outputs, num_uses=num_uses, reuse=reuse)
        else:
          layer_collection.register_fully_connected(
              params, inputs, outputs, reuse=reuse)

    elif record_type is RecordType.conv2d:
      if len(record_list) > 1:
        logging.info('Registering as multi conv2d: %s', params)

        inputs = tuple(record.data['inputs'] for record in record_list)
        outputs = tuple(record.data['outputs'] for record in record_list)
        # Strides/padding/data_format are taken from the first record only.
        strides = record_list[0].data['strides']
        padding = record_list[0].data['padding']
        data_format = record_list[0].data['data_format']
        layer_collection.register_conv2d_multi(
            params,
            strides,
            padding,
            inputs,
            outputs,
            data_format=data_format,
            reuse=reuse)
      else:
        logging.info('Registering as conv2d: %s', params)

        record = record_list[0]
        inputs = record.data['inputs']
        outputs = record.data['outputs']
        strides = record.data['strides']
        padding = record.data['padding']
        data_format = record.data['data_format']
        if num_uses is not None:
          layer_collection.register_conv2d_multi(
              params,
              strides,
              padding,
              inputs,
              outputs,
              data_format=data_format,
              num_uses=num_uses,
              reuse=reuse)
        else:
          layer_collection.register_conv2d(params, strides, padding, inputs,
                                           outputs, data_format=data_format,
                                           reuse=reuse)

    elif record_type is RecordType.scale_and_shift:
      logging.info('Registering as scale (& shift): %s', params)

      if len(record_list) > 1:
        raise ValueError('Multi-use registrations currently not supported for '
                         'scale & shift operations.')
      record = record_list[0]
      inputs = record.data['inputs']
      outputs = record.data['outputs']

      layer_collection.register_scale_and_shift(params, inputs, outputs,
                                                reuse=reuse)

    elif record_type is RecordType.batch_norm:
      logging.info('Registering as batch norm: %s', params)

      if batch_size is None:
        raise AmbiguousRegistrationError(
            'Tried to register a batch norm layer (as generic) without '
            'knowledge of batch_size. You can pass batch_size in to fix this '
            'error.')

      # This is a slight hack. Ideally register_generic would work with lists
      # of params like it used to before we switched to the "unflattened" cov
      # representation so we wouldn't need to detect the approximation type.
      linked_approx = layer_collection._get_linked_approx(params)  # pylint: disable=protected-access
      will_use_diag = (
          linked_approx == 'diagonal'
          or (layer_collection.default_generic_approximation == 'diagonal'
              and linked_approx is None))
      if will_use_diag:
        # Diagonal approximation: register each of the two params separately.
        layer_collection.register_generic(params[0], batch_size, reuse=reuse)
        layer_collection.register_generic(params[1], batch_size, reuse=reuse)
      else:
        layer_collection.register_generic(params, batch_size, reuse=reuse)

    else:
      # An explicit raise (rather than `assert`) so the check survives -O.
      raise ValueError('Invalid record type {}'.format(record_type))

  if mixed_record_type_errors:
    raise ValueError('\n'.join(mixed_record_type_errors))