def _scale_to_z_score_internal()

in tensorflow_transform/mappers.py


def _scale_to_z_score_internal(
    x: common_types.ConsistentTensorType,
    key: Optional[common_types.TensorType], elementwise: bool,
    key_vocabulary_filename: Optional[str],
    output_dtype: Optional[tf.DType]) -> common_types.ConsistentTensorType:
  """Implementation for scale_to_z_score."""
  # x_mean will be float16, float32, or float64, depending on the type of x.
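  # With elementwise=True, instance dimensions are not reduced, so mean and
  # variance are computed independently for each feature position.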
  if key is None:
    x_mean, x_var = analyzers._mean_and_var(  # pylint: disable=protected-access
        x,
        reduce_instance_dims=not elementwise,
        output_dtype=output_dtype)
  else:
    if elementwise:
      raise NotImplementedError('Per-key elementwise reduction not supported')

    mean_and_var_per_key_result = analyzers._mean_and_var_per_key(  # pylint: disable=protected-access
        x, key, key_vocabulary_filename=key_vocabulary_filename,
        output_dtype=output_dtype)
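    # Depending on key_vocabulary_filename, the analyzer result is either a
    # (key_vocab, key_means, key_vars) tuple or a handle to a vocabulary of
    # per-key statistics; the two branches below handle each form.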

    if key_vocabulary_filename is None:
      # Missing keys will translate to 0 for both mean and var, which will be
      # ignored below in the tf.where.
      key_vocab, key_means, key_vars = mean_and_var_per_key_result
      x_mean, x_var = tf_utils.map_per_key_reductions((key_means, key_vars),
                                                      key, key_vocab, x)
    else:
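      # Per-key statistics were stored in a vocabulary file; look up the
      # [mean, variance] row for each instance's key.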
      mean_var_for_key = tf_utils.apply_per_key_vocabulary(
          mean_and_var_per_key_result, key, target_ndims=x.get_shape().ndims)
      x_mean, x_var = (mean_var_for_key[:, 0], mean_var_for_key[:, 1])

  compose_result_fn = _make_composite_tensor_wrapper_if_composite(x)
  x_values = x
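  # For composite tensors, compute the z-score over the backing values and
  # re-wrap the result into the original structure at the end.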

  if isinstance(x, tf.SparseTensor):
    x_values = x.values
    if elementwise:
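      # Broadcast the per-column statistics to the dense shape and gather the
      # value corresponding to each sparse index.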
      x_mean = tf.gather_nd(tf.broadcast_to(x_mean, x.dense_shape), x.indices)
      x_var = tf.gather_nd(tf.broadcast_to(x_var, x.dense_shape), x.indices)
  elif isinstance(x, tf.RaggedTensor):
    if elementwise:
      raise NotImplementedError(
          'Elementwise scale_to_z_score does not support RaggedTensors')
    x_values = x.flat_values

  numerator = tf.cast(x_values, x_mean.dtype) - x_mean
  denominator = tf.sqrt(x_var)
  cond = tf.not_equal(denominator, 0)

  if cond.shape.as_list() != x_values.shape.as_list():
    # Broadcasts cond across the batch dimension when necessary so that it is
    # compatible with the shape of numerator.
    cond = tf.cast(
        tf.zeros_like(numerator) + tf.cast(cond, numerator.dtype),
        dtype=tf.bool)
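  # Where the variance is 0 (constant feature or missing key), avoid division
  # by zero and return the mean-subtracted numerator instead.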

  deviation_values = tf.where(cond, tf.divide(numerator, denominator),
                              numerator)
  return compose_result_fn(deviation_values)
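
Usage sketch (not part of mappers.py): the public mappers tft.scale_to_z_score
and tft.scale_to_z_score_per_key call into this helper; a minimal
preprocessing_fn using them might look like the following. The feature names
'x' and 'group' are hypothetical.

import tensorflow_transform as tft


def preprocessing_fn(inputs):
  return {
      # Standardize 'x' to mean 0 and variance 1 over the whole dataset.
      'x_scaled': tft.scale_to_z_score(inputs['x']),
      # Standardize 'x' within each group defined by the 'group' feature.
      'x_scaled_per_group': tft.scale_to_z_score_per_key(
          inputs['x'], key=inputs['group']),
  }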