def build()

in tensorflow_lattice/python/pwl_calibration_layer.py [0:0]


  def build(self, input_shape):
    """Creates the constants and weights of the layer (Keras build()).

    Sets up, in this order: the keypoint interpolation state (constants for
    "fixed" keypoints, otherwise a logits weight), the calibration kernel
    together with its constraints and regularizers, and - when
    `impute_missing` is set - the tensors used for missing inputs.

    Args:
      input_shape: Passed through unchanged to the parent class build().
    """
    keypoints = np.array(self.input_keypoints)
    # Gap between every pair of neighbouring keypoints.
    segment_lengths = keypoints[1:] - keypoints[:-1]

    if self.input_keypoints_type == "fixed":
      # Interpolation only needs the beginning of each interval, so the last
      # keypoint is left out.
      self._interpolation_keypoints = tf.constant(
          keypoints[:-1], dtype=self.dtype, name=INTERPOLATION_KEYPOINTS_NAME)
      self._lengths = tf.constant(
          segment_lengths, dtype=self.dtype, name=LENGTHS_NAME)
      constraint_lengths = self._lengths
    else:
      self._keypoint_min = keypoints[0]
      self._keypoint_range = keypoints[-1] - keypoints[0]
      # Initial logits are the logs of the gaps scaled by the keypoint range,
      # so that they recover the scaled keypoint gaps of input_keypoints.
      logits_row = np.log(segment_lengths / self._keypoint_range)
      logits_init = tf.constant_initializer(np.tile(logits_row, self.units))
      self.interpolation_logits = self.add_weight(
          INTERPOLATION_LOGITS_NAME,
          shape=[self.units, len(keypoints) - 1],
          initializer=logits_init,
          dtype=self.dtype)
      # Segment lengths are only handed to the constraints for fixed keypoints.
      constraint_lengths = None

    kernel_constraints = PWLCalibrationConstraints(
        monotonicity=self.monotonicity,
        convexity=self.convexity,
        lengths=constraint_lengths,
        output_min=self.output_min,
        output_max=self.output_max,
        output_min_constraints=self._output_min_constraints,
        output_max_constraints=self._output_max_constraints,
        num_projection_iterations=self.num_projection_iterations)

    regularizers = self.kernel_regularizer
    if not regularizers:
      kernel_reg = None
    elif len(regularizers) == 1:
      kernel_reg = regularizers[0]
    else:
      # add_weight() takes a single regularizer, so wrap all of them into one
      # callable which sums up the individual regularization losses.
      def kernel_reg(x):
        return tf.add_n([r(x) for r in self.kernel_regularizer])

    # With 'is_cyclic' the last weight is computed from the previous weights
    # in order to connect the last keypoint with the first one, so one weight
    # less is stored.
    num_weights = keypoints.size - self.is_cyclic

    # The kernel is a units-column matrix. Its first row represents the bias,
    # every following row is the delta in y-value compared to the previous
    # point, i.e. the height of a segment.
    self.kernel = self.add_weight(
        PWL_CALIBRATION_KERNEL_NAME,
        shape=[num_weights, self.units],
        initializer=self.kernel_initializer,
        regularizer=kernel_reg,
        constraint=kernel_constraints,
        dtype=self.dtype)

    if regularizers and not tf.executing_eagerly():
      # Keras handles regularization losses by its own mechanism which does
      # not use GraphKeys. The losses are added to the graph collection as
      # well so that they are easily accessible when the layer is used
      # outside of Keras; this does not interfere with Keras.
      losses_key = tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES
      for regularizer in regularizers:
        tf.compat.v1.add_to_collection(losses_key, regularizer(self.kernel))

    if self.impute_missing:
      if self.missing_input_value is None:
        self._missing_input_value_tensor = None
      else:
        self._missing_input_value_tensor = tf.constant(
            self.missing_input_value,
            dtype=self.dtype,
            name=MISSING_INPUT_VALUE_NAME)

      if self.missing_output_value is None:
        # No explicit output for missing inputs was given: create a weight
        # for it, initialized to the midpoint of the output init range and
        # constrained to [output_min, output_max].
        midpoint = (self._output_init_min + self._output_init_max) / 2.0
        bounds = NaiveBoundsConstraints(
            lower_bound=self.output_min, upper_bound=self.output_max)
        self.missing_output = self.add_weight(
            PWL_CALIBRATION_MISSING_OUTPUT_NAME,
            shape=[1, self.units],
            initializer=keras.initializers.Constant(value=midpoint),
            constraint=bounds,
            dtype=self.dtype)
      else:
        self.missing_output = tf.constant(
            self.missing_output_value, shape=[1, self.units], dtype=self.dtype)

    super(PWLCalibration, self).build(input_shape)