def _infer_feature_schema_common()

in tensorflow_transform/schema_inference.py [0:0]


def _infer_feature_schema_common(
    features: Mapping[str, common_types.TensorType],
    tensor_ranges: Mapping[str, Tuple[int, int]],
    feature_annotations: Mapping[str, List[any_pb2.Any]],
    global_annotations: List[any_pb2.Any],
    is_evaluation_complete: bool) -> schema_pb2.Schema:
  """Given a dict of tensors, creates a `Schema`.

  Args:
    features: A dict mapping column names to `Tensor`, `SparseTensor` or
      `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have
      a 0'th dimension which is interpreted as the batch dimension.
    tensor_ranges: A dict mapping a feature name (a key of `features`) to a
      tuple containing its min and max value. Each listed feature is given a
      categorical `IntDomain` with those bounds.
    feature_annotations: dictionary from feature name to list of any_pb2.Any
      protos to be added as an annotation for that feature in the schema.
      The name of a sparse feature resolves to that sparse feature's value
      feature proto.
    global_annotations: list of any_pb2.Any protos to be added at the global
      schema level.
    is_evaluation_complete: A boolean indicating whether all analyzers have been
      evaluated or not.

  Returns:
    A `Schema` proto.

  Raises:
    KeyError: If `feature_annotations` (or an internally generated tag) names
      a feature that has no corresponding feature proto in the generated
      schema.
  """
  domains = {}
  feature_tags = collections.defaultdict(list)
  for name, tensor in features.items():
    if (isinstance(tensor, tf.RaggedTensor) and
        not common_types.is_ragged_feature_available()):
      # Add the 'ragged_tensor' tag which will cause coder and
      # schema_as_feature_spec to raise an error, as there is no feature spec
      # for ragged tensors in TF 1.x.
      feature_tags[name].append(schema_utils.RAGGED_TENSOR_TAG)
    if name in tensor_ranges:
      min_value, max_value = tensor_ranges[name]
      domains[name] = schema_pb2.IntDomain(
          min=min_value, max=max_value, is_categorical=True)
  feature_spec = _feature_spec_from_batched_tensors(features,
                                                    is_evaluation_complete)

  schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)

  # Add the global annotations to the schema.
  for annotation in global_annotations:
    schema_proto.annotation.extra_metadata.add().CopyFrom(annotation)

  # Build a map from logical feature names to Feature protos. A sparse feature
  # is backed by separate index/value `Feature` entries, so its index features
  # are dropped from the map and its value feature is re-keyed under the
  # sparse feature's own name; annotations for it then land on the value proto.
  feature_protos_by_name = {
      feature.name: feature for feature in schema_proto.feature
  }
  for sparse_feature in schema_proto.sparse_feature:
    for index_feature in sparse_feature.index_feature:
      feature_protos_by_name.pop(index_feature.name)
    value_feature = feature_protos_by_name.pop(
        sparse_feature.value_feature.name)
    feature_protos_by_name[sparse_feature.name] = value_feature

  # Handle ragged tensor representations.
  tensor_representations = (
      tensor_representation_util.GetTensorRepresentationsFromSchema(
          schema_proto, schema_utils.TENSOR_REPRESENTATION_GROUP))
  if tensor_representations is not None:
    for name, tensor_representation in tensor_representations.items():
      # NOTE(review): presumably removes the representation's source columns
      # from the map and returns the proto that annotations should attach to;
      # confirm against schema_utils.pop_ragged_source_columns.
      feature_protos_by_name[name] = schema_utils.pop_ragged_source_columns(
          name, tensor_representation, feature_protos_by_name)

  # Attach per-feature annotations and tags to the resolved Feature protos.
  for feature_name, annotations in feature_annotations.items():
    feature_proto = feature_protos_by_name[feature_name]
    for annotation in annotations:
      feature_proto.annotation.extra_metadata.add().CopyFrom(annotation)
  for feature_name, tags in feature_tags.items():
    feature_protos_by_name[feature_name].annotation.tag.extend(tags)
  return schema_proto