in tensorflow_transform/mappers.py [0:0]
def _scale_to_z_score_internal(
    x: common_types.ConsistentTensorType,
    key: Optional[common_types.TensorType], elementwise: bool,
    key_vocabulary_filename: Optional[str],
    output_dtype: Optional[tf.DType]) -> common_types.ConsistentTensorType:
  """Implementation for scale_to_z_score.

  Subtracts the mean from the values of `x` and divides by the standard
  deviation (square root of the variance). Wherever the standard deviation is
  zero the division is skipped and the mean-centered value is returned.

  Args:
    x: The input to scale. `tf.SparseTensor` and `tf.RaggedTensor` inputs are
      handled through their `values` / `flat_values`; anything else is used
      directly.
    key: Optional key tensor. When None a single mean/variance pair is
      computed over `x`; otherwise statistics are computed per key.
    elementwise: Forwarded to the analyzer as
      `reduce_instance_dims=not elementwise`. Not supported together with
      `key`, nor with `tf.RaggedTensor` inputs.
    key_vocabulary_filename: Forwarded to the per-key analyzer. When None the
      analyzer result is unpacked as (key vocabulary, means, variances);
      otherwise the result is looked up with
      `tf_utils.apply_per_key_vocabulary`.
    output_dtype: Forwarded to the analyzers as `output_dtype`.

  Returns:
    The scaled values, passed through the wrapper produced by
    `_make_composite_tensor_wrapper_if_composite(x)`.

  Raises:
    NotImplementedError: If `elementwise` is True and either `key` is given
      or `x` is a `tf.RaggedTensor`.
  """
  # The statistics come out as float16, float32, or float64, depending on the
  # type of x.
  if key is None:
    x_mean, x_var = analyzers._mean_and_var(  # pylint: disable=protected-access
        x,
        reduce_instance_dims=not elementwise,
        output_dtype=output_dtype)
  else:
    if elementwise:
      raise NotImplementedError('Per-key elementwise reduction not supported')

    per_key_stats = analyzers._mean_and_var_per_key(  # pylint: disable=protected-access
        x,
        key,
        key_vocabulary_filename=key_vocabulary_filename,
        output_dtype=output_dtype)

    if key_vocabulary_filename is not None:
      stats_for_key = tf_utils.apply_per_key_vocabulary(
          per_key_stats, key, target_ndims=x.get_shape().ndims)
      x_mean = stats_for_key[:, 0]
      x_var = stats_for_key[:, 1]
    else:
      # Keys that are missing end up with 0 for both mean and variance; the
      # zero variance is then filtered out by the tf.where at the end.
      vocab, per_key_means, per_key_vars = per_key_stats
      x_mean, x_var = tf_utils.map_per_key_reductions(
          (per_key_means, per_key_vars), key, vocab, x)

  wrap_output = _make_composite_tensor_wrapper_if_composite(x)

  if isinstance(x, tf.SparseTensor):
    values = x.values
    if elementwise:
      # Pick out the per-position statistics matching each stored sparse value.
      x_mean = tf.gather_nd(tf.broadcast_to(x_mean, x.dense_shape), x.indices)
      x_var = tf.gather_nd(tf.broadcast_to(x_var, x.dense_shape), x.indices)
  elif isinstance(x, tf.RaggedTensor):
    if elementwise:
      raise NotImplementedError(
          'Elementwise scale_to_z_score does not support RaggedTensors')
    values = x.flat_values
  else:
    values = x

  centered = tf.cast(values, x_mean.dtype) - x_mean
  std_dev = tf.sqrt(x_var)
  has_spread = tf.not_equal(std_dev, 0)

  if has_spread.shape.as_list() != values.shape.as_list():
    # Expand the mask (e.g. across the batch dimension) so that its shape is
    # compatible with the centered values it will select from.
    zeros = tf.zeros_like(centered)
    mask_as_number = tf.cast(has_spread, centered.dtype)
    has_spread = tf.cast(zeros + mask_as_number, dtype=tf.bool)

  scaled = tf.divide(centered, std_dev)
  z_scores = tf.where(has_spread, scaled, centered)
  return wrap_output(z_scores)