def schema_as_feature_spec()

in tensorflow_transform/tf_metadata/schema_utils.py [0:0]


def schema_as_feature_spec(
    schema_proto: schema_pb2.Schema) -> SchemaAsFeatureSpecResult:
  """Generates a feature spec from a Schema proto.

  For a Feature with a FixedShape we generate a FixedLenFeature with no default.
  For a Feature without a FixedShape we generate a VarLenFeature.  For a
  SparseFeature we generate a SparseFeature.  Ragged `TensorRepresentation`s
  found in the `TENSOR_REPRESENTATION_GROUP` group are converted by
  `_ragged_tensor_representation_as_feature_spec`.

  If `schema_utils_legacy.get_generate_legacy_feature_spec` is true for the
  schema, the result of `_legacy_schema_as_feature_spec` is returned instead.

  Args:
    schema_proto: A Schema proto.

  Returns:
    A pair (feature spec, domains) where feature spec is a dict whose keys are
        feature names and values are instances of FixedLenFeature, VarLenFeature
        or SparseFeature (or the spec produced for a ragged
        `TensorRepresentation`), and `domains` is a dict whose keys are feature
        names and values are one of the `domain_info` oneof, e.g. IntDomain.
        Features without a domain are omitted from `domains`.

  Raises:
    ValueError: If the schema proto is invalid.  This includes a feature tagged
      with `RAGGED_TENSOR_TAG`, and a ragged `TensorRepresentation` whose name
      collides with a plain feature or with a sparse feature.
  """
  # Features tagged as ragged cannot be expressed as a parsing spec at all, so
  # reject them before taking either the legacy or the current code path.
  for feature in schema_proto.feature:
    if RAGGED_TENSOR_TAG in feature.annotation.tag:
      raise ValueError(
          'Feature "{}" had tag "{}".  Features represented by a '
          'RaggedTensor cannot be serialized/deserialized to Example proto or '
          'other formats, and cannot have a feature spec generated for '
          'them.'.format(feature.name, RAGGED_TENSOR_TAG))

  if schema_utils_legacy.get_generate_legacy_feature_spec(schema_proto):
    return _legacy_schema_as_feature_spec(schema_proto)
  feature_spec = {}
  # Will hold the domain_info (IntDomain, FloatDomain etc.) of the feature.  For
  # sparse features, will hold the domain_info of the values feature.  Features
  # that do not have a domain set are dropped from `domains` before returning.
  domains = {}
  feature_by_name = {feature.name: feature for feature in schema_proto.feature}
  string_domains = _get_string_domains(schema_proto)

  # Generate a `tf.SparseFeature` for each element of
  # `schema_proto.sparse_feature`.  `_sparse_feature_as_feature_spec` also
  # removes the referenced component features from `feature_by_name`, so they
  # are not emitted again as standalone features below.
  # TODO(KesterTong): Allow sparse features to share index features.
  for feature in schema_proto.sparse_feature:
    if _include_in_parsing_spec(feature):
      feature_spec[feature.name], domains[feature.name] = (
          _sparse_feature_as_feature_spec(feature, feature_by_name,
                                          string_domains))

  # Handle ragged `TensorRepresentation`s.
  tensor_representations = (
      tensor_representation_util.GetTensorRepresentationsFromSchema(
          schema_proto, TENSOR_REPRESENTATION_GROUP))
  if tensor_representations is not None:
    for name, tensor_representation in tensor_representations.items():
      # A ragged representation must not share its name with a remaining plain
      # feature, nor with a sparse feature already written to `feature_spec`;
      # otherwise one spec would silently overwrite the other.
      if name in feature_by_name or name in feature_spec:
        raise ValueError(
            'Ragged TensorRepresentation name "{}" conflicts with a different '
            'feature in the same schema.'.format(name))
      (feature_spec[name], domains[name]) = (
          _ragged_tensor_representation_as_feature_spec(name,
                                                        tensor_representation,
                                                        feature_by_name,
                                                        string_domains))

  # Generate a `tf.FixedLenFeature` or `tf.VarLenFeature` for each element of
  # `schema_proto.feature` that was not referenced by a `SparseFeature` or a
  # ragged `TensorRepresentation` (those were removed from `feature_by_name`
  # by the helpers above).
  for name, feature in feature_by_name.items():
    if _include_in_parsing_spec(feature):
      feature_spec[name], domains[name] = _feature_as_feature_spec(
          feature, string_domains)

  schema_utils_legacy.check_for_unsupported_features(schema_proto)

  # Helpers return None for features without a domain; omit those entries.
  domains = {
      name: domain for name, domain in domains.items() if domain is not None
  }
  return SchemaAsFeatureSpecResult(feature_spec, domains)