in tensorflow_gnn/graph/schema_validation.py [0:0]
def check_required_features(requirements: schema_pb2.GraphSchema,
                            actual: schema_pb2.GraphSchema):
  """Checks that `actual` provides every feature described by `requirements`.

  This function enables the specification of required features for a piece of
  code. For example, a function accepting a `GraphTensor` instance can use it
  to document which features it expects to find on it. The function accepts
  two schemas: a `requirements` schema, which describes what the code will
  attempt to fetch and use on the `GraphTensor`, and an `actual` schema, which
  describes the dataset. You can use this in your model code to ensure that a
  dataset contains all the node sets, edge sets and features that the model
  uses.

  For each feature in `requirements`, the same feature must exist in `actual`.
  If the required feature sets a `dtype`, the actual dtype must be equal. If it
  sets a `shape`, the actual shape must have the same number of dimensions and
  each dimension must match, subject to the rule below. Features present only
  in `actual` are ignored.

  Note that a dimension with a size of `0` in a feature from the
  `requirements` schema is interpreted specially: it means "accept any value
  for this dimension." The special value `-1` is still used to represent a
  ragged dimension; a required `-1` only matches a `-1` in `actual`.

  (Finally, note that this function predates the existence of
  `GraphTensorSpec`, which is a runtime descriptor for a `GraphTensor`. We may
  eventually provide an equivalent construct using `GraphTensorSpec`.)

  Args:
    requirements: A GraphSchema describing the required features. The `dtype`
      and `shape` of each feature are optional and are only checked when set.
    actual: The GraphSchema to check. It must be a matching superset of
      `requirements`.

  Raises:
    ValidationError: If `actual` does not fulfill the requirements.
  """

  def build_schema_map(schema_):
    """Maps (set_type, set_name, feature_name) to the feature in `schema_`."""
    return {(set_type, set_name, feature_name): feature
            for set_type, set_name, feature_name, feature
            in su.iter_features(schema_)}

  required = build_schema_map(requirements)
  given = build_schema_map(actual)
  for key, required_feature in required.items():
    set_type, set_name, feature_name = key

    # Raise outside of any `except` clause so the traceback does not carry an
    # unrelated, chained KeyError from the dict lookup.
    if key not in given:
      raise ValidationError(
          "{} feature '{}' from set '{}' is missing from given schema".format(
              set_type.capitalize(), feature_name, set_name))
    given_feature = given[key]

    # The dtype is only constrained when the requirements specify one.
    if required_feature.HasField("dtype") and (
        required_feature.dtype != given_feature.dtype):
      raise ValidationError(
          "{} feature '{}' from set '{}' has invalid type: {}".format(
              set_type.capitalize(), feature_name, set_name,
              given_feature.dtype))

    # The shape is only constrained when the requirements specify one. The
    # number of dimensions must agree; a required size of 0 accepts any size,
    # and every other required size (including -1 for ragged) must match
    # exactly.
    if required_feature.HasField("shape"):
      required_dims = required_feature.shape.dim
      given_dims = given_feature.shape.dim
      shape_matches = len(given_dims) == len(required_dims) and all(
          required_dim.size == 0 or required_dim.size == given_dim.size
          for required_dim, given_dim in zip(required_dims, given_dims))
      if not shape_matches:
        raise ValidationError(
            "{} feature '{}' from set '{}' has invalid shape: {}".format(
                set_type.capitalize(), feature_name, set_name,
                given_feature.shape))