def call()

in tensorflow_lattice/python/pwl_calibration_layer.py


  def call(self, inputs):
    """Standard Keras call() method..

    Args:
      inputs: Either input tensor or list of 2 elements: input tensor and
        `is_missing` tensor.

    Returns:
      Calibrated input tensor.

    Raises:
      ValueError: If the `is_missing` tensor is specified incorrectly.
    """
    is_missing = None
    if isinstance(inputs, list):
      # Only 2-element lists are allowed. When such a list is given, the
      # second element represents the 'is_missing' tensor encoded as float
      # values.
      if not self.impute_missing:
        raise ValueError("Multiple inputs for PWLCalibration layer assume "
                         "regular input tensor and 'is_missing' tensor, but "
                         "this instance of a layer is not configured to handle "
                         "missing value. See 'impute_missing' parameter.")
      if len(inputs) > 2:
        raise ValueError("Multiple inputs for PWLCalibration layer assume "
                         "normal input tensor and 'is_missing' tensor, but more"
                         " than 2 tensors given. 'inputs': " + str(inputs))
      if len(inputs) == 2:
        inputs, is_missing = inputs
        if is_missing.shape.as_list() != inputs.shape.as_list():
          raise ValueError(
              "is_missing shape %s does not match inputs shape %s for "
              "PWLCalibration layer" %
              (str(is_missing.shape), str(inputs.shape)))
      else:
        [inputs] = inputs
    if len(inputs.shape) != 2 or (inputs.shape[1] != self.units and
                                  inputs.shape[1] != 1):
      raise ValueError("Shape of input tensor for PWLCalibration layer must be "
                       "[-1, units] or [-1, 1]. It is: " + str(inputs.shape))

    if self.input_keypoints_type == "fixed":
      keypoints_dtype = self._interpolation_keypoints.dtype
    else:
      keypoints_dtype = self.interpolation_logits.dtype
    if inputs.dtype != keypoints_dtype:
      raise ValueError("dtype(%s) of input to PWLCalibration layer does not "
                       "correspond to dtype(%s) of keypoints. You can enforce "
                       "dtype of keypoints by explicitly providing 'dtype' "
                       "parameter to layer constructor or by passing keypoints "
                       "in such format which by default will be converted into "
                       "desired one." % (inputs.dtype, keypoints_dtype))

    # The actual calibration happens here; everything else is handling of
    # missing values.
    if inputs.shape[1] > 1 or (self.input_keypoints_type == "learned_interior"
                               and self.units > 1):
      # Interpolation will have shape [batch_size, units, weights] in these
      # cases. To prepare for that, we add a dimension to the input here to
      # get shape [batch_size, units, 1], or [batch_size, 1, 1] for 1-d input.
      inputs_to_calibration = tf.expand_dims(inputs, -1)
    else:
      inputs_to_calibration = inputs
    if self.input_keypoints_type == "learned_interior":
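      # Reconstruct keypoints from logits: softmax produces positive
      # fractions summing to 1, scaling by the keypoint range turns them into
      # segment lengths, and an exclusive cumsum offset by the minimum
      # keypoint yields monotonically increasing keypoint locations.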
      self._lengths = tf.multiply(
          tf.nn.softmax(self.interpolation_logits, axis=1),
          self._keypoint_range,
          name=LENGTHS_NAME)
      self._interpolation_keypoints = tf.add(
          tf.cumsum(self._lengths, axis=1, exclusive=True),
          self._keypoint_min,
          name=INTERPOLATION_KEYPOINTS_NAME)
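    # Interpolation weights encode where each input falls relative to the
    # keypoints; a weighted sum with the kernel (bias plus per-segment
    # heights) then produces the piecewise-linear output.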
    interpolation_weights = pwl_calibration_lib.compute_interpolation_weights(
        inputs_to_calibration, self._interpolation_keypoints, self._lengths)
    if self.is_cyclic:
      # Append a last height such that all heights sum to 0.0, which makes
      # the calibrator cyclic.
      bias_and_heights = tf.concat(
          [self.kernel, -tf.reduce_sum(self.kernel[1:], axis=0, keepdims=True)],
          axis=0)
    else:
      bias_and_heights = self.kernel

    # bias_and_heights has shape [weights, units].
    if len(interpolation_weights.shape) > 2:
      # Multi dim input has interpolation shape [batch_size, units, weights].
      result = tf.reduce_sum(
          interpolation_weights * tf.transpose(bias_and_heights), axis=-1)
    else:
      # Single dim input has interpolation shape [batch_size, weights].
      result = tf.matmul(interpolation_weights, bias_and_heights)

    if self.impute_missing:
      if is_missing is None:
        if self.missing_input_value is None:
          raise ValueError("PWLCalibration layer is configured to impute "
                           "missing but no 'missing_input_value' specified and "
                           "'is_missing' tensor is not given.")
        assert self._missing_input_value_tensor is not None
        is_missing = tf.cast(
            tf.equal(inputs, self._missing_input_value_tensor),
            dtype=self.dtype)
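      # Blend: where the input is missing, emit the (possibly learned)
      # missing_output; otherwise keep the calibrated result.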
      result = is_missing * self.missing_output + (1.0 - is_missing) * result

    if self.units > 1 and self.split_outputs:
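      # Return one [batch_size, 1] tensor per unit instead of a single
      # [batch_size, units] tensor.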
      result = tf.split(result, self.units, axis=1)

    return result
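
For context, here is a minimal usage sketch exercising the missing-value
handling above. It assumes the public tfl.layers.PWLCalibration wrapper; the
keypoints, sentinel value, and data are illustrative, not prescriptive.

import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl

# Calibrator over [0, 1] that treats inputs equal to -1.0 as missing and
# learns the output for missing values.
calibrator = tfl.layers.PWLCalibration(
    input_keypoints=np.linspace(0.0, 1.0, num=5),
    units=1,
    impute_missing=True,
    missing_input_value=-1.0,
)
features = tf.constant([[0.1], [0.5], [-1.0]])  # last row is "missing"
print(calibrator(features))

# Equivalently, missingness can be passed explicitly as a second tensor.
is_missing = tf.constant([[0.0], [0.0], [1.0]])
print(calibrator([features, is_missing]))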