def build()

in tensorflow_lattice/python/rtl_layer.py [0:0]


  def build(self, input_shape):
    """Standard Keras build() method.

    Creates one lattice sub-layer per entry of the RTL structure returned by
    `self._get_rtl_structure(input_shape)` and stores it in
    `self._lattice_layers`, keyed by `str(monotonicities)`. The sub-layer type
    is `lattice_layer.Lattice` for the 'all_vertices' parameterization and
    `kfll.KroneckerFactoredLattice` for 'kronecker_factored'.

    Args:
      input_shape: Input shape passed in by Keras. It is validated together
        with `self.lattice_size` by `rtl_lib.verify_hyperparameters` and used
        to derive the RTL structure.

    Raises:
      ValueError: If `self.parameterization` is neither 'all_vertices' nor
        'kronecker_factored'. The check runs inside the per-structure loop, so
        it is only reached when the RTL structure is non-empty.
    """
    rtl_lib.verify_hyperparameters(
        lattice_size=self.lattice_size, input_shape=input_shape)
    # Convert kernel regularizers to proper form (tuples). A list whose first
    # element is a string (presumably a regularizer name followed by its
    # parameters) becomes one tuple. Any other non-empty list is treated as a
    # list of regularizers, and each one is converted to a tuple. An empty
    # list is passed through unchanged; the `[0]` lookup would otherwise raise
    # IndexError.
    kernel_regularizer = self.kernel_regularizer
    if isinstance(self.kernel_regularizer, list) and self.kernel_regularizer:
      if isinstance(self.kernel_regularizer[0], six.string_types):
        kernel_regularizer = tuple(self.kernel_regularizer)
      else:
        kernel_regularizer = [tuple(r) for r in self.kernel_regularizer]
    self._rtl_structure = self._get_rtl_structure(input_shape)
    # dict from monotonicities to the lattice layers with those monotonicities.
    self._lattice_layers = {}
    for monotonicities, inputs_for_units in self._rtl_structure:
      # Suffix used to name both the passthrough constant and the sub-layer.
      monotonicities_str = ''.join(
          [str(monotonicity) for monotonicity in monotonicities])
      # Passthrough names for reconstructing model graph.
      inputs_for_units_name = '{}_{}'.format(INPUTS_FOR_UNITS_PREFIX,
                                             monotonicities_str)
      # Use control dependencies to save inputs_for_units as graph constant for
      # visualisation toolbox to be able to recover it from saved graph.
      # Wrap this constant into pure op since in TF 2.0 there are issues passing
      # tensors into control_dependencies.
      with tf.control_dependencies([
          tf.constant(
              inputs_for_units, dtype=tf.int32, name=inputs_for_units_name)
      ]):
        # One output unit per group of inputs assigned to this structure entry.
        units = len(inputs_for_units)
        if self.parameterization == 'all_vertices':
          layer_name = '{}_{}'.format(RTL_LATTICE_NAME, monotonicities_str)
          # Every lattice uses the same size along each of its
          # `self.lattice_rank` inputs.
          lattice_sizes = [self.lattice_size] * self.lattice_rank
          kernel_initializer = lattice_layer.create_kernel_initializer(
              kernel_initializer_id=self.kernel_initializer,
              lattice_sizes=lattice_sizes,
              monotonicities=monotonicities,
              output_min=self.output_min,
              output_max=self.output_max,
              unimodalities=None,
              joint_unimodalities=None,
              init_min=self.init_min,
              init_max=self.init_max)
          self._lattice_layers[str(monotonicities)] = lattice_layer.Lattice(
              lattice_sizes=lattice_sizes,
              units=units,
              monotonicities=monotonicities,
              output_min=self.output_min,
              output_max=self.output_max,
              num_projection_iterations=self.num_projection_iterations,
              monotonic_at_every_step=self.monotonic_at_every_step,
              clip_inputs=self.clip_inputs,
              interpolation=self.interpolation,
              kernel_initializer=kernel_initializer,
              kernel_regularizer=kernel_regularizer,
              name=layer_name,
          )
        elif self.parameterization == 'kronecker_factored':
          layer_name = '{}_{}'.format(RTL_KFL_NAME, monotonicities_str)
          kernel_initializer = kfll.create_kernel_initializer(
              kernel_initializer_id=self.kernel_initializer,
              monotonicities=monotonicities,
              output_min=self.output_min,
              output_max=self.output_max,
              init_min=self.init_min,
              init_max=self.init_max)
          # NOTE(review): kernel_regularizer, interpolation and the projection
          # settings are not forwarded in this branch. Presumably
          # KroneckerFactoredLattice does not accept them; confirm.
          self._lattice_layers[str(
              monotonicities)] = kfll.KroneckerFactoredLattice(
                  lattice_sizes=self.lattice_size,
                  units=units,
                  num_terms=self.num_terms,
                  monotonicities=monotonicities,
                  output_min=self.output_min,
                  output_max=self.output_max,
                  clip_inputs=self.clip_inputs,
                  kernel_initializer=kernel_initializer,
                  scale_initializer='scale_initializer',
                  name=layer_name)
        else:
          raise ValueError('Unknown type of parameterization: {}'.format(
              self.parameterization))
    super(RTL, self).build(input_shape)