def set_domain()

in tensorflow_data_validation/utils/schema_util.py [0:0]


def set_domain(schema: schema_pb2.Schema, feature_path: types.FeaturePath,
               domain: Any) -> None:
  """Assigns `domain` to the feature at `feature_path` within `schema`.

  Any domain already set on the feature is replaced (a warning is logged when
  that happens). A string `domain` is treated as the name of a global string
  domain and must match an entry already present in `schema.string_domain`;
  this function never creates a new global domain.

  Args:
    schema: A Schema protocol buffer.
    feature_path: The path of the feature whose domain needs to be set. If a
      FeatureName is passed, a one-step FeaturePath will be constructed and
      used. For example, "my_feature" -> types.FeaturePath(["my_feature"])
    domain: A domain protocol buffer (any message type accepted by the
      `domain_info` oneof of `schema_pb2.Feature`) or the name of a global
      string domain present in the input schema.

  Example:

  ```python
    >>> from tensorflow_metadata.proto.v0 import schema_pb2
    >>> import tensorflow_data_validation as tfdv
    >>> schema = schema_pb2.Schema()
    >>> schema.feature.add(name='feature')
    # Setting a int domain.
    >>> int_domain = schema_pb2.IntDomain(min=3, max=5)
    >>> tfdv.set_domain(schema, "feature", int_domain)
    # Setting a string domain.
    >>> str_domain = schema_pb2.StringDomain(value=['one', 'two', 'three'])
    >>> tfdv.set_domain(schema, "feature", str_domain)
  ```

  Raises:
    TypeError: If the input schema or the domain is not of the expected type,
      or if the target feature is a STRUCT feature.
    ValueError: If an invalid global string domain is provided as input.
  """
  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  # Build a lookup from each Python type accepted by the `domain_info` oneof
  # to the name of the oneof field that stores a value of that type.
  field_name_by_type = {}
  domain_info = schema_pb2.Feature.DESCRIPTOR.oneofs_by_name['domain_info']
  for field in domain_info.fields:
    if field.message_type is not None:
      accepted_type = getattr(schema_pb2, field.message_type.name)
    elif field.type == descriptor.FieldDescriptor.TYPE_STRING:
      accepted_type = str
    else:
      raise TypeError('Unexpected type within schema.Features.domain_info')
    field_name_by_type[accepted_type] = field.name

  if not isinstance(domain, tuple(field_name_by_type)):
    raise TypeError('domain is of type %s, should be one of the supported types'
                    ' in schema.Features.domain_info' % type(domain).__name__)

  feature = get_feature(schema, feature_path)
  if feature.type == schema_pb2.STRUCT:
    raise TypeError('Could not set the domain of a STRUCT feature %s.' %
                    feature_path)

  if feature.WhichOneof('domain_info') is not None:
    logging.warning('Replacing existing domain of feature "%s".', feature_path)

  for accepted_type, field_name in field_name_by_type.items():
    if not isinstance(domain, accepted_type):
      continue
    if accepted_type is str:
      # A string refers to a global string domain by name; it must already be
      # declared at the schema level before a feature can point at it.
      is_known = any(
          global_domain.name == domain
          for global_domain in schema.string_domain)
      if not is_known:
        raise ValueError('Invalid global string domain "{}".'.format(domain))
      feature.domain = domain
    else:
      # Message-typed domains are copied into the matching oneof field.
      getattr(feature, field_name).CopyFrom(domain)