in tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py
def _WriteMetricsPlotsAndValidations(  # pylint: disable=invalid-name
    evaluation: evaluator.Evaluation,
    output_paths: Dict[str, str],
    eval_config: config_pb2.EvalConfig,
    add_metrics_callbacks: List[types.AddMetricsCallbackType],
    metrics_key: str,
    plots_key: str,
    attributions_key: str,
    validations_key: str,
    output_file_format: str,
    rubber_stamp: bool = False) -> beam.pvalue.PDone:
"""PTransform to write metrics and plots."""
  if output_file_format not in _SUPPORTED_FORMATS:
    raise ValueError('only "{}" formats are currently supported but got '
                     'output_file_format={}'.format(_SUPPORTED_FORMATS,
                                                    output_file_format))

  def convert_slice_key_to_parquet_dict(
      slice_key: metrics_for_slice_pb2.SliceKey) -> _SliceKeyDictPythonType:
    single_slice_key_dicts = []
    for single_slice_key in slice_key.single_slice_keys:
      kind = single_slice_key.WhichOneof('kind')
      if not kind:
        continue
      single_slice_key_dicts.append({kind: getattr(single_slice_key, kind)})
    return {_SINGLE_SLICE_KEYS_PARQUET_FIELD_NAME: single_slice_key_dicts}
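
  # Illustrative example (values assumed, not taken from the original source):
  # for a SliceKey whose single entry has its 'float_value' oneof field set,
  # convert_slice_key_to_parquet_dict returns roughly
  #   {_SINGLE_SLICE_KEYS_PARQUET_FIELD_NAME: [{'float_value': 0.5}]}.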

  def convert_to_parquet_columns(
      value: Union[metrics_for_slice_pb2.MetricsForSlice,
                   metrics_for_slice_pb2.PlotsForSlice,
                   metrics_for_slice_pb2.AttributionsForSlice]
  ) -> Dict[str, Union[_SliceKeyDictPythonType, bytes]]:
    return {
        _SLICE_KEY_PARQUET_COLUMN_NAME:
            convert_slice_key_to_parquet_dict(value.slice_key),
        _SERIALIZED_VALUE_PARQUET_COLUMN_NAME:
            value.SerializeToString()
    }
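
  # Each Parquet row therefore carries two columns matching
  # _SLICED_PARQUET_SCHEMA: the structured slice key and the serialized proto.
  # Read-back sketch (illustrative; assumes pyarrow and an already written
  # output file, neither of which is part of this module):
  #   import pyarrow.parquet as pq
  #   table = pq.read_table('/tmp/eval_output/metrics-00000-of-00001.parquet')
  #   col = table.column(_SERIALIZED_VALUE_PARQUET_COLUMN_NAME)
  #   metrics = metrics_for_slice_pb2.MetricsForSlice.FromString(
  #       col[0].as_py())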

  if metrics_key in evaluation and constants.METRICS_KEY in output_paths:
    metrics = (
        evaluation[metrics_key] | 'ConvertSliceMetricsToProto' >> beam.Map(
            convert_slice_metrics_to_proto,
            add_metrics_callbacks=add_metrics_callbacks))
    file_path_prefix = output_paths[constants.METRICS_KEY]
    if output_file_format == _PARQUET_FORMAT:
      _ = (
          metrics
          | 'ConvertToParquetColumns' >> beam.Map(convert_to_parquet_columns)
          | 'WriteMetricsToParquet' >> beam.io.WriteToParquet(
              file_path_prefix=file_path_prefix,
              schema=_SLICED_PARQUET_SCHEMA,
              file_name_suffix='.' + output_file_format))
    elif output_file_format == _TFRECORD_FORMAT:
      _ = metrics | 'WriteMetrics' >> beam.io.WriteToTFRecord(
          file_path_prefix=file_path_prefix,
          shard_name_template=None if output_file_format else '',
          file_name_suffix=('.' + output_file_format
                            if output_file_format else ''),
          coder=beam.coders.ProtoCoder(metrics_for_slice_pb2.MetricsForSlice))
    else:
      raise ValueError(
          f'Unsupported output file format: {output_file_format}.')

  if plots_key in evaluation and constants.PLOTS_KEY in output_paths:
    plots = (
        evaluation[plots_key] | 'ConvertSlicePlotsToProto' >> beam.Map(
            convert_slice_plots_to_proto,
            add_metrics_callbacks=add_metrics_callbacks))
    file_path_prefix = output_paths[constants.PLOTS_KEY]
    if output_file_format == _PARQUET_FORMAT:
      _ = (
          plots
          | 'ConvertPlotsToParquetColumns' >> beam.Map(
              convert_to_parquet_columns)
          | 'WritePlotsToParquet' >> beam.io.WriteToParquet(
              file_path_prefix=file_path_prefix,
              schema=_SLICED_PARQUET_SCHEMA,
              file_name_suffix='.' + output_file_format))
    elif output_file_format == _TFRECORD_FORMAT:
      _ = plots | 'WritePlotsToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=file_path_prefix,
          shard_name_template=None if output_file_format else '',
          file_name_suffix=('.' + output_file_format
                            if output_file_format else ''),
          coder=beam.coders.ProtoCoder(metrics_for_slice_pb2.PlotsForSlice))
    else:
      raise ValueError(
          f'Unsupported output file format: {output_file_format}.')

  if (attributions_key in evaluation and
      constants.ATTRIBUTIONS_KEY in output_paths):
    attributions = (
        evaluation[attributions_key] | 'ConvertSliceAttributionsToProto' >>
        beam.Map(convert_slice_attributions_to_proto))
    file_path_prefix = output_paths[constants.ATTRIBUTIONS_KEY]
    if output_file_format == _PARQUET_FORMAT:
      _ = (
          attributions
          | 'ConvertAttributionsToParquetColumns' >> beam.Map(
              convert_to_parquet_columns)
          | 'WriteAttributionsToParquet' >> beam.io.WriteToParquet(
              file_path_prefix=file_path_prefix,
              schema=_SLICED_PARQUET_SCHEMA,
              file_name_suffix='.' + output_file_format))
    elif output_file_format == _TFRECORD_FORMAT:
      _ = (
          attributions
          | 'WriteAttributionsToTFRecord' >> beam.io.WriteToTFRecord(
              file_path_prefix=file_path_prefix,
              shard_name_template=None if output_file_format else '',
              file_name_suffix=('.' + output_file_format
                                if output_file_format else ''),
              coder=beam.coders.ProtoCoder(
                  metrics_for_slice_pb2.AttributionsForSlice)))
    else:
      raise ValueError(
          f'Unsupported output file format: {output_file_format}.')

  if (validations_key in evaluation and
      constants.VALIDATIONS_KEY in output_paths):
    validations = (
        evaluation[validations_key]
        | 'MergeValidationResults' >> beam.CombineGlobally(
            CombineValidations(eval_config, rubber_stamp=rubber_stamp)))
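    # The CombineGlobally above merges the per-slice validation outcomes into
    # a single ValidationResult proto (the type used by the TFRecord coder
    # below).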
    file_path_prefix = output_paths[constants.VALIDATIONS_KEY]
    # We only use a single shard here because validations are usually single
    # values. Setting the shard_name_template to the empty string forces this.
    shard_name_template = ''
    if output_file_format == _PARQUET_FORMAT:
      _ = (
          validations
          | 'ConvertValidationsToParquetColumns' >> beam.Map(
              lambda v:  # pylint: disable=g-long-lambda
              {_SERIALIZED_VALUE_PARQUET_COLUMN_NAME: v.SerializeToString()})
          | 'WriteValidationsToParquet' >> beam.io.WriteToParquet(
              file_path_prefix=file_path_prefix,
              shard_name_template=shard_name_template,
              schema=_UNSLICED_PARQUET_SCHEMA,
              file_name_suffix='.' + output_file_format))
    elif output_file_format == _TFRECORD_FORMAT:
      _ = (
          validations
          | 'WriteValidationsToTFRecord' >> beam.io.WriteToTFRecord(
              file_path_prefix=file_path_prefix,
              shard_name_template=shard_name_template,
              file_name_suffix=('.' + output_file_format
                                if output_file_format else ''),
              coder=beam.coders.ProtoCoder(
                  validation_result_pb2.ValidationResult)))
    else:
      raise ValueError(
          f'Unsupported output file format: {output_file_format}.')

  return beam.pvalue.PDone(list(evaluation.values())[0].pipeline)
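

# Usage sketch (illustrative only; the paths and keys below are assumptions,
# and this assumes the function is wrapped with @beam.ptransform_fn, which is
# not shown in this excerpt; in practice it is normally constructed via this
# module's public MetricsPlotsAndValidationsWriter rather than used directly):
#
#   _ = evaluation | 'WriteResults' >> _WriteMetricsPlotsAndValidations(
#       output_paths={constants.METRICS_KEY: '/tmp/eval_output/metrics'},
#       eval_config=eval_config,
#       add_metrics_callbacks=[],
#       metrics_key=constants.METRICS_KEY,
#       plots_key=constants.PLOTS_KEY,
#       attributions_key=constants.ATTRIBUTIONS_KEY,
#       validations_key=constants.VALIDATIONS_KEY,
#       output_file_format=_TFRECORD_FORMAT)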