in tfx/components/example_validator/executor.py
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]
       ) -> execution_result_pb2.ExecutorOutput:
  """TensorFlow ExampleValidator executor entrypoint.

  This validates statistics against the schema.

  Args:
    input_dict: Input dict from input key to a list of artifacts, including:
      - statistics: A list of type `standard_artifacts.ExampleStatistics`
        generated by StatisticsGen.
      - schema: A list of type `standard_artifacts.Schema` which should
        contain a single schema artifact.
    output_dict: Output dict from key to a list of artifacts, including:
      - anomalies: A list of type `standard_artifacts.ExampleAnomalies` of
        size one. It will include a single binary proto file which contains
        all anomalies found.
    exec_properties: A dict of execution properties, including:
      - exclude_splits: JSON-serialized list of names of splits that the
        example validator should not validate.
      - custom_validation_config: An optional configuration for specifying
        custom validations with SQL.

  Returns:
    ExecutorOutput proto containing the updated anomalies artifact.
  """
  self._log_startup(input_dict, output_dict, exec_properties)
  # Load and deserialize exclude splits from execution properties.
  exclude_splits = json_utils.loads(
      exec_properties.get(standard_component_specs.EXCLUDE_SPLITS_KEY,
                          'null')) or []
  if not isinstance(exclude_splits, list):
    raise ValueError('exclude_splits in execution properties needs to be a '
                     'list. Got %s instead.' % type(exclude_splits))
  # Setup output splits.
  stats_artifact = artifact_utils.get_single_instance(
      input_dict[standard_component_specs.STATISTICS_KEY])
  stats_split_names = artifact_utils.decode_split_names(
      stats_artifact.split_names)
  split_names = [
      split for split in stats_split_names if split not in exclude_splits
  ]
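  # The anomalies artifact mirrors the statistics artifact: it carries the
  # filtered split names and the same span.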
  anomalies_artifact = artifact_utils.get_single_instance(
      output_dict[standard_component_specs.ANOMALIES_KEY])
  anomalies_artifact.split_names = artifact_utils.encode_split_names(
      split_names)
  anomalies_artifact.span = stats_artifact.span
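  # Read the schema proto from the single file under the schema artifact's
  # URI.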
  schema = io_utils.SchemaReader().read(
      io_utils.get_only_uri_in_dir(
          artifact_utils.get_single_uri(
              input_dict[standard_component_specs.SCHEMA_KEY])))
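  # Validate each remaining split and record whether it passed ("blessed").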
  blessed_value_dict = {}
  for split in split_names:
    logging.info(
        'Validating schema against the computed statistics for '
        'split %s.', split)
    stats = stats_artifact_utils.load_statistics(stats_artifact,
                                                 split).proto()
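    # Bundle the statistics, schema, and optional custom validation config
    # into the labeled inputs expected by _Validate.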
    label_inputs = {
        standard_component_specs.STATISTICS_KEY:
            stats,
        standard_component_specs.SCHEMA_KEY:
            schema,
        standard_component_specs.CUSTOM_VALIDATION_CONFIG_KEY:
            exec_properties.get(
                standard_component_specs.CUSTOM_VALIDATION_CONFIG_KEY),
    }
    output_uri = artifact_utils.get_split_uri(
        output_dict[standard_component_specs.ANOMALIES_KEY], split)
    label_outputs = {labels.SCHEMA_DIFF_PATH: output_uri}
    anomalies = self._Validate(label_inputs, label_outputs)
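    # A split is blessed only if validation found neither feature-level nor
    # dataset-level anomalies.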
    if anomalies.anomaly_info or anomalies.HasField('dataset_anomaly_info'):
      blessed_value_dict[split] = NOT_BLESSED_VALUE
    else:
      blessed_value_dict[split] = BLESSED_VALUE
    logging.info(
        'Validation complete for split %s. Anomalies written to '
        '%s.', split, output_uri)
  # Set blessed custom property for anomalies artifact.
  anomalies_artifact.set_json_value_custom_property(
      ARTIFACT_PROPERTY_BLESSED_KEY, blessed_value_dict
  )
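  # Return the mutated artifact through ExecutorOutput so its new split
  # names, span, and blessed property are recorded in MLMD.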
  executor_output = execution_result_pb2.ExecutorOutput()
  executor_output.output_artifacts[
      standard_component_specs.ANOMALIES_KEY
  ].artifacts.append(anomalies_artifact.mlmd_artifact)
  return executor_output
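
# --- Usage sketch (not part of executor.py) ---
# A minimal, hedged illustration of how a downstream step could gate on the
# per-split blessing recorded above. `anomalies_artifact` stands for the
# ExampleAnomalies artifact this executor produced, and
# `get_json_value_custom_property` is assumed to be the reader counterpart of
# the setter used above; treat the names here as assumptions, not a confirmed
# TFX gating API.
def check_blessed(anomalies_artifact) -> None:
  """Raises if any validated split failed example validation."""
  blessed = anomalies_artifact.get_json_value_custom_property(
      ARTIFACT_PROPERTY_BLESSED_KEY) or {}
  for split, value in blessed.items():
    if value == NOT_BLESSED_VALUE:
      raise RuntimeError('Example validation failed for split %s.' % split)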