in tensorflow_transform/schema_inference.py [0:0]
def _infer_feature_schema_common(
features: Mapping[str, common_types.TensorType],
tensor_ranges: Mapping[str, Tuple[int, int]],
feature_annotations: Mapping[str, List[any_pb2.Any]],
global_annotations: List[any_pb2.Any],
is_evaluation_complete: bool) -> schema_pb2.Schema:
"""Given a dict of tensors, creates a `Schema`.
Args:
features: A dict mapping column names to `Tensor`, `SparseTensor` or
`RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have
a 0'th dimension which is interpreted as the batch dimension.
tensor_ranges: A dict mapping a tensor to a tuple containing its min and max
value.
feature_annotations: dictionary from feature name to list of any_pb2.Any
protos to be added as an annotation for that feature in the schema.
global_annotations: list of any_pb2.Any protos to be added at the global
schema level.
is_evaluation_complete: A boolean indicating whether all analyzers have been
evaluated or not.
Returns:
A `Schema` proto.
"""
domains = {}
feature_tags = collections.defaultdict(list)
for name, tensor in features.items():
if (isinstance(tensor, tf.RaggedTensor) and
not common_types.is_ragged_feature_available()):
# Add the 'ragged_tensor' tag which will cause coder and
# schema_as_feature_spec to raise an error, as there is no feature spec
# for ragged tensors in TF 1.x.
feature_tags[name].append(schema_utils.RAGGED_TENSOR_TAG)
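# Tensors with an analyzer-computed (min, max) range get an integer domain
# marked as categorical.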
if name in tensor_ranges:
min_value, max_value = tensor_ranges[name]
domains[name] = schema_pb2.IntDomain(
min=min_value, max=max_value, is_categorical=True)
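# Derive a feature spec from the batched output tensors and convert it to a
# `Schema` proto, attaching the integer domains computed above.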
feature_spec = _feature_spec_from_batched_tensors(features,
is_evaluation_complete)
schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains)
# Add the global annotations to the schema.
for annotation in global_annotations:
schema_proto.annotation.extra_metadata.add().CopyFrom(annotation)
# Build a map from logical feature names to Feature protos
feature_protos_by_name = {}
for feature in schema_proto.feature:
feature_protos_by_name[feature.name] = feature
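# Sparse features are represented in the schema as separate index/value
# features. Drop the index features and re-key the value feature under the
# sparse feature's logical name so annotations can be attached to it below.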
for sparse_feature in schema_proto.sparse_feature:
for index_feature in sparse_feature.index_feature:
feature_protos_by_name.pop(index_feature.name)
value_feature = feature_protos_by_name.pop(
sparse_feature.value_feature.name)
feature_protos_by_name[sparse_feature.name] = value_feature
# Handle ragged tensor representations: collapse each ragged feature's source
# columns into a single entry keyed by the logical feature name.
tensor_representations = (
tensor_representation_util.GetTensorRepresentationsFromSchema(
schema_proto, schema_utils.TENSOR_REPRESENTATION_GROUP))
if tensor_representations is not None:
for name, tensor_representation in tensor_representations.items():
feature_protos_by_name[name] = schema_utils.pop_ragged_source_columns(
name, tensor_representation, feature_protos_by_name)
# Attach per-feature annotations and any feature tags collected above.
for feature_name, annotations in feature_annotations.items():
feature_proto = feature_protos_by_name[feature_name]
for annotation in annotations:
feature_proto.annotation.extra_metadata.add().CopyFrom(annotation)
for feature_name, tags in feature_tags.items():
feature_proto = feature_protos_by_name[feature_name]
for tag in tags:
feature_proto.annotation.tag.append(tag)
return schema_proto
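# A minimal sketch of how this helper might be exercised (hypothetical values;
# in the library it is reached through the public infer_feature_schema entry
# points rather than called directly):
#
#   features = {'x': tf.constant([[1], [2]], dtype=tf.int64)}
#   schema = _infer_feature_schema_common(
#       features,
#       tensor_ranges={'x': (0, 9)},   # analyzer-computed (min, max) for 'x'
#       feature_annotations={},        # no per-feature annotations
#       global_annotations=[],         # no schema-level annotations
#       is_evaluation_complete=True)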