in tensorflow_data_validation/utils/schema_util.py [0:0]
def set_domain(schema: schema_pb2.Schema, feature_path: types.FeaturePath,
               domain: Any) -> None:
    """Sets the domain for the input feature in the schema.

    If the input feature already has a domain, it is overwritten with the
    newly provided input domain. This method cannot be used to add a new
    global domain.

    Args:
      schema: A Schema protocol buffer.
      feature_path: The name of the feature whose domain needs to be set. If a
        FeatureName is passed, a one-step FeaturePath will be constructed and
        used. For example, "my_feature" -> types.FeaturePath(["my_feature"]).
        The value is forwarded as-is to `get_feature`, which performs the
        lookup.
      domain: A domain protocol buffer or the name of a global string domain
        present in the input schema.

    Example:

    ```python
      >>> from tensorflow_metadata.proto.v0 import schema_pb2
      >>> import tensorflow_data_validation as tfdv
      >>> schema = schema_pb2.Schema()
      >>> schema.feature.add(name='feature')
      # Setting a int domain.
      >>> int_domain = schema_pb2.IntDomain(min=3, max=5)
      >>> tfdv.set_domain(schema, "feature", int_domain)
      # Setting a string domain.
      >>> str_domain = schema_pb2.StringDomain(value=['one', 'two', 'three'])
      >>> tfdv.set_domain(schema, "feature", str_domain)
    ```

    Raises:
      TypeError: If the input schema or the domain is not of the expected
        type, or if the target feature is a STRUCT feature.
      ValueError: If an invalid global string domain is provided as input.
    """
    if not isinstance(schema, schema_pb2.Schema):
        raise TypeError('schema is of type %s, should be a Schema proto.' %
                        type(schema).__name__)

    # Build a map from every Python type accepted by the Feature.domain_info
    # oneof to the name of the oneof field that stores it. Message-typed
    # fields map their generated class; the single string-typed field (the
    # reference to a global string domain) maps `str`.
    feature_domains = {}
    for f in schema_pb2.Feature.DESCRIPTOR.oneofs_by_name['domain_info'].fields:
        if f.message_type is not None:
            feature_domains[getattr(schema_pb2, f.message_type.name)] = f.name
        elif f.type == descriptor.FieldDescriptor.TYPE_STRING:
            feature_domains[str] = f.name
        else:
            raise TypeError('Unexpected type within schema.Features.domain_info')

    if not isinstance(domain, tuple(feature_domains)):
        raise TypeError('domain is of type %s, should be one of the supported types'
                        ' in schema.Features.domain_info' % type(domain).__name__)

    feature = get_feature(schema, feature_path)
    if feature.type == schema_pb2.STRUCT:
        raise TypeError('Could not set the domain of a STRUCT feature %s.' %
                        feature_path)

    # A string `domain` must name a global string domain already declared in
    # the schema. Validate this before logging or mutating anything so that a
    # rejected call does not emit a misleading "Replacing existing domain"
    # warning.
    if isinstance(domain, str) and not any(
            global_domain.name == domain
            for global_domain in schema.string_domain):
        raise ValueError('Invalid global string domain "{}".'.format(domain))

    if feature.WhichOneof('domain_info') is not None:
        logging.warning('Replacing existing domain of feature "%s".', feature_path)

    for d_type, d_name in feature_domains.items():
        if not isinstance(domain, d_type):
            continue
        if d_type is str:
            # Scalar oneof field: plain assignment selects it.
            feature.domain = domain
        else:
            # Message oneof field: CopyFrom selects it and copies the payload.
            getattr(feature, d_name).CopyFrom(domain)