in tensorflow_datasets/core/example_parser.py [0:0]
def _to_tf_example_spec(tensor_info):
  """Convert a `TensorInfo` into a feature proto object.

  The returned spec(s) describe how to decode the corresponding feature of a
  serialized `tf.train.Example`. For ragged features, decoding must match the
  encoding from `_add_ragged_fields` in `example_serializer.py`.

  Args:
    tensor_info: The `TensorInfo` describing the feature. Only its `dtype`,
      `shape`, `default_value` and `sequence_rank` attributes are read.

  Returns:
    * A `tf.io.FixedLenFeature` when every dimension of `tensor_info.shape` is
      known.
    * A `tf.io.FixedLenSequenceFeature` when exactly one dimension is unknown.
      The unknown dimension is dropped from the spec shape; the original
      shape is restored in `_deserialize_single_field`.
    * A `dict` of `tf.io.FixedLenSequenceFeature` specs when several
      dimensions are unknown and `tensor_info.sequence_rank > 1`. It holds
      `sequence_rank - 1` int64 entries named `ragged_row_lengths_{k}` plus
      one `ragged_flat_values` entry.

  Raises:
    NotImplementedError: If `tensor_info.dtype` is not an integer, bool,
      floating or string dtype.
    NotImplementedError: If several dimensions are unknown and
      `tensor_info.sequence_rank <= 1`.
  """
  # Convert the dtype.
  # TODO(b/119937875): TF Examples proto only support int64, float32 and string
  # This creates limitations like float64 downsampled to float32, bool converted
  # to int64 which is space inefficient, no support for complexes or quantized.
  if tensor_info.dtype.is_integer or tensor_info.dtype.is_bool:
    dtype = tf.int64
  elif tensor_info.dtype.is_floating:
    dtype = tf.float32
  elif tensor_info.dtype == tf.string:
    dtype = tf.string
  else:
    # TFRecord / tf.train.Example only support these 3 storage types.
    # Report the offending dtype explicitly; keep the full tensor info as
    # context for debugging.
    raise NotImplementedError(
        "Serialization not implemented for dtype {} (tensor info: {})".format(
            tensor_info.dtype, tensor_info))

  # Convert the shape: the kind of feature spec depends on how many
  # dimensions are unknown (None). Count them once, using the same
  # `is None` test for every branch.
  num_unknown_dims = sum(1 for dim in tensor_info.shape if dim is None)

  if num_unknown_dims == 0:
    # All dimensions are defined.
    return tf.io.FixedLenFeature(
        shape=tensor_info.shape,
        dtype=dtype,
        default_value=tensor_info.default_value,
    )

  if num_unknown_dims == 1:
    # Extract the defined shape (without the None dimension).
    # The original shape is restored in `_deserialize_single_field`.
    shape = tuple(dim for dim in tensor_info.shape if dim is not None)
    return tf.io.FixedLenSequenceFeature(
        shape=shape,
        dtype=dtype,
        allow_missing=True,
        default_value=tensor_info.default_value,
    )

  if tensor_info.sequence_rank > 1:  # RaggedTensor
    # Decoding here should match encoding from `_add_ragged_fields` in
    # `example_serializer.py`: `sequence_rank - 1` row-lengths features plus
    # the flattened values.
    tf_specs = {  # pylint: disable=g-complex-comprehension
        "ragged_row_lengths_{}".format(k): tf.io.FixedLenSequenceFeature(  # pylint: disable=g-complex-comprehension
            shape=(),
            dtype=tf.int64,
            allow_missing=True,
        ) for k in range(tensor_info.sequence_rank - 1)
    }
    # NOTE(review): assumes the dims after the first `sequence_rank` ones are
    # all known; a None there would be rejected by TF when parsing — confirm.
    tf_specs["ragged_flat_values"] = tf.io.FixedLenSequenceFeature(
        shape=tensor_info.shape[tensor_info.sequence_rank:],
        dtype=dtype,
        allow_missing=True,
        default_value=tensor_info.default_value,
    )
    return tf_specs

  raise NotImplementedError(
      "Multiple unknown dimension not supported.\n"
      "If using `tfds.features.Tensor`, please set "
      "`Tensor(..., encoding='zlib')` (or 'bytes', or 'gzip')")