def _make_cast_fn()

in tensorflow_transform/coders/example_proto_coder.py [0:0]


def _make_cast_fn(np_dtype):
  """Return a function that prepares values for a tf.train feature list.

  The returned function converts a value so it can be appended to / extended
  into the `value` field of a `tf.train.FloatList`, `tf.train.Int64List` or
  (for any other dtype) a bytes list. For performance reasons it is preferred
  to have the cast fn constructed once (for each handler): constructing it
  for a numeric dtype probes the installed protobuf library.

  Args:
    np_dtype: The numpy type (a class, passed to `issubclass`) of the
      Tensorflow feature, e.g. a subclass of `np.floating` or `np.integer`;
      any other type is treated as a string feature.

  Returns:
    For floating and integer dtypes: an identity function when the protobuf
    list accepts numpy scalars, 0-d arrays and 1-d arrays directly, otherwise
    a function that converts numpy scalars/arrays to Python values via
    `tolist()`. For all other dtypes: a function that UTF-8 encodes `str`
    values and passes `bytes` through, returning a list for lists, tuples and
    arrays with ndim > 0, and a single `bytes` value otherwise.
  """

  # There seems to be a great degree of variability for handling automatic
  # conversions across types and across API implementation of the Python
  # protocol buffer library.
  #
  # For the 'python' implementation we need to always "cast" from np types to
  # the appropriate Python type.
  #
  # For the 'cpp' implementation we need to only "cast" from np types to the
  # appropriate Python type for "Float" types, but only for protobuf < 3.2.0.
  # Rather than checking versions, the numeric branches below probe the
  # library once and pick the cheapest function that works.

  def identity(x):
    return x

  def numeric_cast(x):
    if isinstance(x, (np.generic, np.ndarray)):
      # This works for both np.generic and np.ndarray (of any shape).
      return x.tolist()
    # Python scalars (or lists thereof) require no casting.
    return x

  # This is in agreement with Tensorflow conversions for Unicode values for both
  # Python 2 and 3 (and also works for non-Unicode objects). It is also in
  # agreement with the testWithUnicode of the Beam impl.
  def utf8(s):
    return s if isinstance(s, bytes) else s.encode('utf-8')

  def string_cast(x):
    if isinstance(x, (list, tuple)) or (
        isinstance(x, np.ndarray) and x.ndim > 0):
      # Build a real list: in Python 3 `map` is a lazy, single-use iterator
      # that supports neither len() nor indexing.
      return [utf8(v) for v in x]
    if isinstance(x, np.ndarray):
      # 0-d array: unwrap its single element. `.item()` avoids round-tripping
      # through numpy's fixed-width bytes dtype, which strips trailing NULs.
      return utf8(x.item())
    return utf8(x)

  def accepts_numpy(proto_list, np_scalar, np_zero_d, np_vector):
    """Return True iff `proto_list.value` accepts numpy values uncast."""
    try:
      proto_list.value.append(np_scalar)
      proto_list.value.append(np_zero_d)
      proto_list.value.extend(np_vector)
    except TypeError:
      return False
    return True

  # Any dummy values will do for the probes below.
  if issubclass(np_dtype, np.floating):
    if accepts_numpy(tf.train.FloatList(), np.float32(0.1), np.array(0.1),
                     np.array([0.1, 0.2])):
      return identity
    return numeric_cast
  if issubclass(np_dtype, np.integer):
    if accepts_numpy(tf.train.Int64List(), np.int64(1), np.array(1),
                     np.array([1, 2])):
      return identity
    return numeric_cast

  return string_cast