in tensorflow_transform/tf_metadata/schema_utils.py [0:0]
def schema_as_feature_spec(
    schema_proto: schema_pb2.Schema) -> SchemaAsFeatureSpecResult:
  """Converts a Schema proto into a parsing feature spec plus domains.

  A Feature with a FixedShape becomes a FixedLenFeature with no default, a
  Feature without a FixedShape becomes a VarLenFeature, and a SparseFeature
  becomes a SparseFeature. Ragged `TensorRepresentation`s found in the schema
  get their own entries as well.

  When the schema requests the legacy feature spec, the conversion is
  delegated entirely to `_legacy_schema_as_feature_spec`.

  Args:
    schema_proto: A Schema proto.

  Returns:
    A pair (feature spec, domains). The feature spec is a dict keyed by
    feature name whose values are instances of FixedLenFeature, VarLenFeature
    or SparseFeature. `domains` is a dict keyed by feature name whose values
    are one of the `domain_info` oneof, e.g. IntDomain; names without a domain
    are omitted from it.

  Raises:
    ValueError: If the schema proto is invalid, e.g. a feature carries
      `RAGGED_TENSOR_TAG`, or a ragged TensorRepresentation shares its name
      with a feature that is still waiting to be converted.
  """
  # A single feature tagged as a RaggedTensor makes the whole schema
  # unconvertible, so reject it before doing any other work.
  for candidate in schema_proto.feature:
    if RAGGED_TENSOR_TAG not in candidate.annotation.tag:
      continue
    raise ValueError(
        'Feature "{}" had tag "{}". Features represented by a '
        'RaggedTensor cannot be serialized/deserialized to Example proto or '
        'other formats, and cannot have a feature spec generated for '
        'them.'.format(candidate.name, RAGGED_TENSOR_TAG))

  if schema_utils_legacy.get_generate_legacy_feature_spec(schema_proto):
    return _legacy_schema_as_feature_spec(schema_proto)

  parsing_spec = {}
  # Maps a feature name to its domain_info (IntDomain, FloatDomain etc.); for
  # a sparse feature this is the domain_info of its values feature. Entries
  # may temporarily be None and are filtered out just before returning.
  domain_by_name = {}
  # Features still waiting to be converted, keyed by name.
  pending_features = {}
  for candidate in schema_proto.feature:
    pending_features[candidate.name] = candidate
  string_domains = _get_string_domains(schema_proto)

  # Emit one entry per element of `schema_proto.sparse_feature`. The helper
  # receives `pending_features` and is expected to remove the features it
  # consumes from it.
  # TODO(KesterTong): Allow sparse features to share index features.
  for sparse in schema_proto.sparse_feature:
    if not _include_in_parsing_spec(sparse):
      continue
    sparse_spec, sparse_domain = _sparse_feature_as_feature_spec(
        sparse, pending_features, string_domains)
    parsing_spec[sparse.name] = sparse_spec
    domain_by_name[sparse.name] = sparse_domain

  # Emit one entry per ragged `TensorRepresentation`, if the schema has any.
  ragged_representations = (
      tensor_representation_util.GetTensorRepresentationsFromSchema(
          schema_proto, TENSOR_REPRESENTATION_GROUP))
  if ragged_representations is None:
    ragged_items = ()
  else:
    ragged_items = ragged_representations.items()
  for ragged_name, representation in ragged_items:
    # A representation may not reuse the name of a feature that has not been
    # consumed yet.
    if ragged_name in pending_features:
      raise ValueError(
          'Ragged TensorRepresentation name "{}" conflicts with a different '
          'feature in the same schema.'.format(ragged_name))
    ragged_spec, ragged_domain = (
        _ragged_tensor_representation_as_feature_spec(
            ragged_name, representation, pending_features, string_domains))
    parsing_spec[ragged_name] = ragged_spec
    domain_by_name[ragged_name] = ragged_domain

  # Whatever is left in `pending_features` was not claimed by a
  # `SparseFeature` or a ragged `TensorRepresentation`; each such feature
  # becomes a `tf.FixedLenFeature` or a `tf.VarLenFeature`.
  for plain_name, plain_feature in pending_features.items():
    if not _include_in_parsing_spec(plain_feature):
      continue
    plain_spec, plain_domain = _feature_as_feature_spec(
        plain_feature, string_domains)
    parsing_spec[plain_name] = plain_spec
    domain_by_name[plain_name] = plain_domain

  schema_utils_legacy.check_for_unsupported_features(schema_proto)

  # Only features that actually have a domain are reported.
  known_domains = {}
  for key, domain in domain_by_name.items():
    if domain is not None:
      known_domains[key] = domain
  return SchemaAsFeatureSpecResult(parsing_spec, known_domains)