in tensorflow_lattice/python/pwl_calibration_layer.py [0:0]
def call(self, inputs):
  """Standard Keras call() method.

  Calibrates the input with the layer's piecewise-linear function and, when
  `impute_missing` is enabled, replaces the output at missing positions with
  `missing_output`.

  Args:
    inputs: Either an input tensor, or a list/tuple of 1 or 2 elements: the
      input tensor and, optionally, an `is_missing` tensor of the same shape.
      The input tensor must have shape [-1, units] or [-1, 1] and the same
      dtype as the keypoints. `is_missing` is expected to be 1.0 where the
      input is missing and 0.0 elsewhere; it is cast to the dtype of the
      calibrated result, so boolean masks are accepted as well.

  Returns:
    Calibrated input tensor. If `units > 1` and `split_outputs` is set, a list
    of `units` tensors obtained by splitting the result along axis 1.

  Raises:
    ValueError: If `is_missing` tensor specified incorrectly (given while the
      layer does not impute missing values, or with a shape different from the
      input), if the input list/tuple is empty or has more than 2 elements, if
      the input tensor has an unsupported shape or a dtype different from the
      keypoints, or if missing values must be imputed but neither
      `missing_input_value` nor an `is_missing` tensor is available.
  """
  is_missing = None
  # Multiple inputs may arrive as a list or as a tuple; both are unpacked the
  # same way. A tf.Tensor is neither, so single-tensor inputs skip this branch.
  if isinstance(inputs, (list, tuple)):
    # Only 1 or 2 element sequences are allowed. When 2 elements are given, the
    # second one is the 'is_missing' tensor.
    if not self.impute_missing:
      raise ValueError("Multiple inputs for PWLCalibration layer assume "
                       "regular input tensor and 'is_missing' tensor, but "
                       "this instance of a layer is not configured to handle "
                       "missing value. See 'impute_missing' parameter.")
    if len(inputs) > 2:
      raise ValueError("Multiple inputs for PWLCalibration layer assume "
                       "normal input tensor and 'is_missing' tensor, but more"
                       " than 2 tensors given. 'inputs': " + str(inputs))
    if not inputs:
      # Without this check an empty sequence fails below with an opaque
      # "not enough values to unpack" error.
      raise ValueError("Inputs for PWLCalibration layer were given as an "
                       "empty list/tuple; expected the input tensor and "
                       "optionally an 'is_missing' tensor.")
    if len(inputs) == 2:
      inputs, is_missing = inputs
      if is_missing.shape.as_list() != inputs.shape.as_list():
        raise ValueError(
            "is_missing shape %s does not match inputs shape %s for "
            "PWLCalibration layer" %
            (str(is_missing.shape), str(inputs.shape)))
    else:
      [inputs] = inputs

  if len(inputs.shape) != 2 or (inputs.shape[1] != self.units and
                                inputs.shape[1] != 1):
    raise ValueError("Shape of input tensor for PWLCalibration layer must be "
                     "[-1, units] or [-1, 1]. It is: " + str(inputs.shape))

  # Keypoints are either stored directly ("fixed") or derived from trainable
  # logits; either way the input dtype must match theirs.
  if self.input_keypoints_type == "fixed":
    keypoints_dtype = self._interpolation_keypoints.dtype
  else:
    keypoints_dtype = self.interpolation_logits.dtype
  if inputs.dtype != keypoints_dtype:
    raise ValueError("dtype(%s) of input to PWLCalibration layer does not "
                     "correspond to dtype(%s) of keypoints. You can enforce "
                     "dtype of keypoints by explicitly providing 'dtype' "
                     "parameter to layer constructor or by passing keypoints "
                     "in such format which by default will be converted into "
                     "desired one." % (inputs.dtype, keypoints_dtype))

  # Here is calibration. Everything else is handling of missing.
  if inputs.shape[1] > 1 or (self.input_keypoints_type == "learned_interior"
                             and self.units > 1):
    # Interpolation will have shape [batch_size, units, weights] in these
    # cases. To prepare for that, we add a dimension to the input here to get
    # shape [batch_size, units, 1] or [batch_size, 1, 1] if 1d input.
    inputs_to_calibration = tf.expand_dims(inputs, -1)
  else:
    inputs_to_calibration = inputs

  if self.input_keypoints_type == "learned_interior":
    # Segment lengths are softmax fractions (along axis 1) of the keypoint
    # range; an exclusive cumulative sum offset by the minimum keypoint then
    # yields the left keypoint of every segment.
    self._lengths = tf.multiply(
        tf.nn.softmax(self.interpolation_logits, axis=1),
        self._keypoint_range,
        name=LENGTHS_NAME)
    self._interpolation_keypoints = tf.add(
        tf.cumsum(self._lengths, axis=1, exclusive=True),
        self._keypoint_min,
        name=INTERPOLATION_KEYPOINTS_NAME)
  interpolation_weights = pwl_calibration_lib.compute_interpolation_weights(
      inputs_to_calibration, self._interpolation_keypoints, self._lengths)

  if self.is_cyclic:
    # Append a final height equal to minus the sum of the other heights
    # (kernel[1:]) so that all heights sum up to 0.0, which makes the
    # calibrator cyclic.
    bias_and_heights = tf.concat(
        [self.kernel, -tf.reduce_sum(self.kernel[1:], axis=0, keepdims=True)],
        axis=0)
  else:
    bias_and_heights = self.kernel

  # bias_and_heights has shape [weight, units].
  if len(interpolation_weights.shape) > 2:
    # Multi dim input has interpolation shape [batch_size, units, weights].
    result = tf.reduce_sum(
        interpolation_weights * tf.transpose(bias_and_heights), axis=-1)
  else:
    # Single dim input has interpolation shape [batch_size, weights].
    result = tf.matmul(interpolation_weights, bias_and_heights)

  if self.impute_missing:
    if is_missing is None:
      if self.missing_input_value is None:
        raise ValueError("PWLCalibration layer is configured to impute "
                         "missing but no 'missing_input_value' specified and "
                         "'is_missing' tensor is not given.")
      # Internal invariant: a configured missing_input_value is expected to
      # come with its tensor counterpart.
      assert self._missing_input_value_tensor is not None
      is_missing = tf.cast(
          tf.equal(inputs, self._missing_input_value_tensor),
          dtype=self.dtype)
    else:
      # A caller-provided mask may be boolean/integer; arithmetic below needs
      # it in the same floating dtype as the result. No-op for matching floats.
      is_missing = tf.cast(is_missing, dtype=result.dtype)
    # Missing positions take `missing_output`; others keep the calibration.
    result = is_missing * self.missing_output + (1.0 - is_missing) * result

  if self.units > 1 and self.split_outputs:
    result = tf.split(result, self.units, axis=1)

  return result