in tensorflow_transform/coders/example_proto_coder.py [0:0]
def _make_cast_fn(np_dtype):
"""Return a function to extract the typed value from the feature.
For performance reasons it is preferred to have the cast fn
constructed once (for each handler).
Args:
np_dtype: The numpy type of the Tensorflow feature.
Returns:
A function to extract the value field from a string depending on dtype.
"""
# There seems to be a great degree of variability for handling automatic
# conversions across types and across API implementation of the Python
# protocol buffer library.
#
# For the 'python' implementation we need to always "cast" from np types to
# the appropriate Python type.
#
# For the 'cpp' implementation we need to only "cast" from np types to the
# appropriate Python type for "Float" types, but only for protobuf < 3.2.0
def identity(x):
return x
def numeric_cast(x):
if isinstance(x, (np.generic, np.ndarray)):
# This works for both np.generic and np.array (of any shape).
return x.tolist()
# This works for python scalars (or lists thereof), which require no
# casting.
return x
# This is in agreement with Tensorflow conversions for Unicode values for both
# Python 2 and 3 (and also works for non-Unicode objects). It is also in
# agreement with the testWithUnicode of the Beam impl.
def utf8(s):
return s if isinstance(s, bytes) else s.encode('utf-8')
vectorize = np.vectorize(utf8)
def string_cast(x):
if isinstance(x, list) or isinstance(x, np.ndarray) and x.ndim > 0:
return map(utf8, x)
elif isinstance(x, np.ndarray):
return vectorize(x).tolist()
return utf8(x)
if issubclass(np_dtype, np.floating):
try:
float_list = tf.train.FloatList()
float_list.value.append(np.float32(0.1)) # Any dummy value will do.
float_list.value.append(np.array(0.1)) # Any dummy value will do.
float_list.value.extend(np.array([0.1, 0.2])) # Any dummy values will do.
return identity
except TypeError:
return numeric_cast
elif issubclass(np_dtype, np.integer):
try:
int64_list = tf.train.Int64List()
int64_list.value.append(np.int64(1)) # Any dummy value will do.
int64_list.value.append(np.array(1)) # Any dummy value will do.
int64_list.value.extend(np.array([1, 2])) # Any dummy values will do.
return identity
except TypeError:
return numeric_cast
return string_cast