def check_required_features()

in tensorflow_gnn/graph/schema_validation.py [0:0]


def check_required_features(requirements: schema_pb2.GraphSchema,
                            actual: schema_pb2.GraphSchema):
  """Checks the requirements of a given schema against another.

  This function is used to enable the specification of required features to a
  function. A function accepting a `GraphTensor` instance can this way document
  what features it is expecting to find on it. The function accepts two schemas:
  a `requirements` schema which describes what the function will attempt to
  fetch and use on the `GraphTensor`, and an `actual` schema instance, which is
  the schema describing the dataset. You can use this in your model code to
  ensure that a dataset contains all the expected node sets, edge sets and
  features that the model uses.

  Only features listed in `requirements` are checked; additional features in
  `actual` are ignored. A required feature's `dtype` and `shape` are compared
  only if that field is explicitly set on the required feature.

  Note that a dimension with a size of `0` in a feature from the `requirements`
  schema is interpreted specially: it means "accept any value for this
  dimension." The special value `-1` is still used to represent a ragged
  dimension.

  (Finally, note that this function predates the existence of `GraphTensorSpec`,
  which is a runtime descriptor for a `GraphTensor`. We may eventually provide
  an equivalent construct using the `GraphTensorSpec`.)

  Args:
    requirements: An instance of a GraphSchema object, with optional shapes.
    actual: The schema to check; it must be a matching superset of
      `requirements`.

  Raises:
    ValidationError: If the given schema does not fulfill the requirements.
  """
  def build_schema_map(schema_):
    # Index every feature of the schema by (set_type, set_name, feature_name).
    return {(set_type, set_name, feature_name): feature
            for set_type, set_name, feature_name, feature
            in su.iter_features(schema_)}

  required = build_schema_map(requirements)
  given = build_schema_map(actual)
  for key, required_feature in required.items():
    set_type, set_name, feature_name = key
    try:
      given_feature = given[key]
    except KeyError:
      # `from None`: the lookup KeyError is an implementation detail; suppress
      # it so the ValidationError is reported on its own.
      raise ValidationError(
          "{} feature '{}' from set '{}' is missing from given schema".format(
              set_type.capitalize(), feature_name, set_name)) from None

    if required_feature.HasField("dtype") and (
        required_feature.dtype != given_feature.dtype):
      raise ValidationError(
          "{} feature '{}' from set '{}' has invalid type: {}".format(
              set_type.capitalize(), feature_name, set_name,
              given_feature.dtype))

    if required_feature.HasField("shape"):
      required_dims = required_feature.shape.dim
      given_dims = given_feature.shape.dim
      # The rank must match exactly. Per dimension, a required size of 0 is a
      # wildcard; any other required size (including -1) must match exactly.
      if len(given_dims) != len(required_dims) or any(
          required_dim.size != 0 and given_dim.size != required_dim.size
          for required_dim, given_dim in zip(required_dims, given_dims)):
        raise ValidationError(
            "{} feature '{}' from set '{}' has invalid shape: {}".format(
                set_type.capitalize(), feature_name, set_name,
                given_feature.shape))