def _get_rtl_structure()

in tensorflow_lattice/python/rtl_layer.py


  def _get_rtl_structure(self, input_shape):
    """Returns the RTL structure for the given input_shape.

    Args:
      input_shape: Input shape to the layer. Must be a dict matching the format
        described in the layer description, or a single shape, which is
        treated as fully unconstrained.

    Raises:
      ValueError: If the structure is too small to include all the inputs.

    Returns:
      A list of `(monotonicities, lattices)` tuples, where `monotonicities` is
      the tuple of lattice monotonicities, and `lattices` is a list of lists of
      indices into the flattened input to the layer.
    """
    if not isinstance(input_shape, dict):
      input_shape = {'unconstrained': input_shape}

    # Calculate the flattened input to the RTL layer. rtl_inputs will be a list
    # of _RTLInput items, each recording the monotonicity, input group and
    # input index of one input to the layer.
    # The order for flattening should match the order in the call method.
    rtl_inputs = []
    group = 0  # group id for the input
    input_index = 0  # index into the flattened input
    for input_key in sorted(input_shape.keys()):
      shapes = input_shape[input_key]
      if input_key == 'unconstrained':
        monotonicity = 0
      elif input_key == 'increasing':
        monotonicity = 1
      else:
        raise ValueError(
            'Unrecognized key in the input to the RTL layer: {}'.format(
                input_key))

      if not isinstance(shapes, list):
        # Get the shape after a split. See single dense tensor input format in
        # the layer comments.
        shapes = [(shapes[0], 1)] * shapes[1]
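        # e.g. (illustrative) shapes == (None, 3) expands to
        # [(None, 1), (None, 1), (None, 1)].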

      for shape in shapes:
        for _ in range(shape[1]):
          rtl_inputs.append(
              _RTLInput(
                  monotonicity=monotonicity,
                  group=group,
                  input_index=input_index))
          input_index += 1
        group += 1
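    # Worked example (illustrative): for the input_shape
    #   {'increasing': (None, 3), 'unconstrained': [(None, 2), (None, 1)]}
    # the sorted keys produce groups 0-2 for the three split 'increasing'
    # features (indices 0-2), group 3 for the 2-d unconstrained input
    # (indices 3-4), and group 4 for the remaining 1-d input (index 5).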

    total_usage = self.num_lattices * self.lattice_rank
    if total_usage < len(rtl_inputs):
      raise ValueError(
          'RTL layer with {}x{}D lattices is too small to use all the {} input '
          'features'.format(self.num_lattices, self.lattice_rank,
                            len(rtl_inputs)))

    # Repeat the features to fill all the slots in the RTL layer.
    rs = np.random.RandomState(self.random_seed)
    rs.shuffle(rtl_inputs)
    rtl_inputs = rtl_inputs * (1 + total_usage // len(rtl_inputs))
    rtl_inputs = rtl_inputs[:total_usage]
    rs.shuffle(rtl_inputs)
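    # e.g. (illustrative) with 7 inputs and total_usage = 8: tiling makes
    # 8 // 7 + 1 = 2 copies (14 items) and truncation keeps 8, so every input
    # appears at least once and exactly one input appears twice.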

    # Start with random lattices, possibly with repeated groups in lattices.
    lattices = []
    for lattice_index in range(self.num_lattices):
      lattices.append(
          rtl_inputs[lattice_index * self.lattice_rank:(lattice_index + 1) *
                     self.lattice_rank])
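    # e.g. (illustrative) with lattice_rank = 2: lattice 0 takes slots [0:2],
    # lattice 1 takes slots [2:4], and so on through the shuffled list.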

    # Swap features between lattices to make sure only a single input from each
    # group is used in each lattice.
    changed = True
    iteration = 0
    while changed and self.avoid_intragroup_interaction:
      if iteration > _MAX_RTL_SWAPS:
        logging.info('Some lattices in the RTL layer might use features from '
                     'the same input group')
        break
      changed = False
      iteration += 1
      for lattice_0, lattice_1 in itertools.combinations(lattices, 2):
        # For every pair of lattices: lattice_0, lattice_1
        for index_0, index_1 in itertools.product(
            range(len(lattice_0)), range(len(lattice_1))):
          # Consider swapping lattice_0[index_0] with lattice_1[index_1]
          rest_lattice_0 = list(lattice_0)
          rest_lattice_1 = list(lattice_1)
          feature_0 = rest_lattice_0.pop(index_0)
          feature_1 = rest_lattice_1.pop(index_1)
          if feature_0.group == feature_1.group:
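            # Swapping two features from the same group leaves each lattice's
            # group multiset unchanged, so skip this pair.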
            continue

          # Swap if a group is repeated and a swap fixes it.
          rest_lattice_groups_0 = list(
              lattice_input.group for lattice_input in rest_lattice_0)
          rest_lattice_groups_1 = list(
              lattice_input.group for lattice_input in rest_lattice_1)
          if ((feature_0.group in rest_lattice_groups_0) and
              (feature_0.group not in rest_lattice_groups_1) and
              (feature_1.group not in rest_lattice_groups_0)):
            lattice_0[index_0], lattice_1[index_1] = (lattice_1[index_1],
                                                      lattice_0[index_0])
            changed = True

    # Arrange into combined lattice layers. Lattices with the same
    # monotonicities can use the same tfl.layers.Lattice layer.
    # Create a dict: monotonicity -> list of list of input indices.
    lattices_for_monotonicities = collections.defaultdict(list)
    for lattice in lattices:
      lattice.sort(key=lambda lattice_input: lattice_input.monotonicity)
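      # Sorting gives each lattice a canonical monotonicity ordering, so
      # lattices with the same multiset of monotonicities map to the same
      # dict key and can be merged into one tfl.layers.Lattice layer.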
      monotonicities = tuple(
          lattice_input.monotonicity for lattice_input in lattice)
      lattice_input_indices = list(
          lattice_input.input_index for lattice_input in lattice)
      lattices_for_monotonicities[monotonicities].append(lattice_input_indices)

    return sorted(lattices_for_monotonicities.items())
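
Usage sketch (illustrative, not from the library). The method reads only
num_lattices, lattice_rank, avoid_intragroup_interaction and random_seed from
self, and relies on module-level names defined elsewhere in rtl_layer.py:
collections, itertools, logging, np (numpy), _RTLInput and _MAX_RTL_SWAPS. The
driver below re-declares the two private names as assumptions (the value of
_MAX_RTL_SWAPS is a guess) and calls the method as a free function on a bare
namespace object, purely for demonstration.

import collections
import types

# Assumed re-declarations of private names from rtl_layer.py.
_RTLInput = collections.namedtuple(
    '_RTLInput', ['monotonicity', 'group', 'input_index'])
_MAX_RTL_SWAPS = 10000  # assumed cap on swap passes

# Stand-in for the RTL layer instance, carrying only the fields read above.
layer = types.SimpleNamespace(
    num_lattices=4,
    lattice_rank=2,
    avoid_intragroup_interaction=True,
    random_seed=42)

# 7 inputs fit into 4 lattices of rank 2 (8 slots), so one input repeats.
structure = _get_rtl_structure(layer, {
    'unconstrained': (None, 5),             # one dense tensor, 5 features
    'increasing': [(None, 1), (None, 1)],   # two separate 1-d inputs
})
for monotonicities, lattice_indices in structure:
  print(monotonicities, lattice_indices)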