in tfx_bsl/tfxio/tensor_representation_util.py [0:0]
def _LegacyInferTensorRepresentationFromSchema(
    schema: schema_pb2.Schema) -> Dict[Text, schema_pb2.TensorRepresentation]:
  """Translate a Feature proto into a TensorRepresentation proto.

  A Feature proto describes a feature as it appears in an Example proto, but
  parsing that feature into a tensor also needs a shape and a representation.
  This function fills that gap with legacy heuristics:

  1. Shape and representation are chosen from value_count:
     * value_count.min == value_count.max == 1: scalar shape, fixed-length
       (dense) representation.
     * value_count.min == value_count.max > 1: rank-1 shape of length
       value_count.max, fixed-length (dense) representation.
     * value_count.min == value_count.max < 1, or min != max: rank-1 shape of
       unknown length, variable-length (sparse) representation.
  2. A canonical default value (from _LEGACY_DEFAULT_VALUE_FOR_FEATURE_TYPE)
     is attached only to fixed-length features that are not always present;
     variable-length and always-present features get none.
  3. Deprecated features are dropped entirely.

  Args:
    schema: A Schema proto.

  Returns:
    A Dict mapping tensor names to their TensorRepresentations.

  Raises:
    ValueError: If the feature's type is not supported or the schema is
      invalid.
  """
  result = {}
  for feature in schema.feature:
    # Rule 3: deprecated features are skipped entirely.
    if not _ShouldIncludeFeature(feature):
      continue
    min_count = feature.value_count.min
    max_count = feature.value_count.max
    # Reject malformed value counts before applying the shape heuristics.
    if min_count < 0:
      raise ValueError(
          "Feature {} has value_count.min < 0 (value was {}).".format(
              feature.name, min_count))
    if max_count < 0:
      raise ValueError(
          "Feature {} has value_count.max < 0 (value was {}).".format(
              feature.name, max_count))
    # Rule 1: pick shape and representation from the value counts.
    if min_count == max_count == 1:
      # Exactly one value per example: a scalar DenseTensor with a
      # canonical default value (rule 2) when the feature may be absent.
      logging.info(
          "Feature %s has value_count.min == value_count.max == 1. Setting to "
          "DenseTensor.", feature.name)
      dense = schema_pb2.TensorRepresentation.DenseTensor(
          column_name=feature.name,
          shape=schema_pb2.FixedShape(),
          default_value=_LegacyInferDefaultValue(feature))
      result[feature.name] = schema_pb2.TensorRepresentation(dense_tensor=dense)
    elif min_count == max_count and min_count > 1:
      # Fixed number (>1) of values per example: a rank-1 DenseTensor whose
      # length is the shared count, plus a default value (rule 2).
      logging.info(
          "Feature %s has value_count.min == value_count.max > 1. Setting to "
          "DenseTensor.", feature.name)
      fixed_shape = schema_pb2.FixedShape(
          dim=[schema_pb2.FixedShape.Dim(size=min_count)])
      dense = schema_pb2.TensorRepresentation.DenseTensor(
          column_name=feature.name,
          shape=fixed_shape,
          default_value=_LegacyInferDefaultValue(feature))
      result[feature.name] = schema_pb2.TensorRepresentation(dense_tensor=dense)
    else:
      # Remaining cases (min != max, or min == max == 0): the length varies
      # per example, so represent as a VarLenSparseTensor with no default.
      logging.info(
          "Feature %s has value_count.min != value_count.max or "
          "value_count.min == value_count.max == 0. "
          "Setting to VarLenSparseTensor.", feature.name)
      varlen = schema_pb2.TensorRepresentation.VarLenSparseTensor(
          column_name=feature.name)
      result[feature.name] = schema_pb2.TensorRepresentation(
          varlen_sparse_tensor=varlen)
  return result