def _WriteMetricsPlotsAndValidations()

in tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py
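
Writes the metrics, plots, attributions, and validation results of an
evaluation to disk, in either Parquet or TFRecord format.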


def _WriteMetricsPlotsAndValidations(  # pylint: disable=invalid-name
    evaluation: evaluator.Evaluation,
    output_paths: Dict[str, str],
    eval_config: config_pb2.EvalConfig,
    add_metrics_callbacks: List[types.AddMetricsCallbackType],
    metrics_key: str,
    plots_key: str,
    attributions_key: str,
    validations_key: str,
    output_file_format: str,
    rubber_stamp: bool = False) -> beam.pvalue.PDone:
  """PTransform to write metrics and plots."""

  if output_file_format not in _SUPPORTED_FORMATS:
    raise ValueError(
        f'Only {_SUPPORTED_FORMATS} formats are currently supported, but got '
        f'output_file_format={output_file_format}.')

  def convert_slice_key_to_parquet_dict(
      slice_key: metrics_for_slice_pb2.SliceKey) -> _SliceKeyDictPythonType:
    """Converts a SliceKey proto to its parquet dict representation."""
    single_slice_key_dicts = []
    for single_slice_key in slice_key.single_slice_keys:
      kind = single_slice_key.WhichOneof('kind')
      if not kind:
        # Skip single slice keys whose value oneof is unset.
        continue
      single_slice_key_dicts.append({kind: getattr(single_slice_key, kind)})
    return {_SINGLE_SLICE_KEYS_PARQUET_FIELD_NAME: single_slice_key_dicts}

  def convert_to_parquet_columns(
      value: Union[metrics_for_slice_pb2.MetricsForSlice,
                   metrics_for_slice_pb2.PlotsForSlice,
                   metrics_for_slice_pb2.AttributionsForSlice]
  ) -> Dict[str, Union[_SliceKeyDictPythonType, bytes]]:
    """Converts a sliced result proto into a two-column parquet record."""
    return {
        _SLICE_KEY_PARQUET_COLUMN_NAME:
            convert_slice_key_to_parquet_dict(value.slice_key),
        _SERIALIZED_VALUE_PARQUET_COLUMN_NAME:
            value.SerializeToString()
    }

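  # Each result type below is written only when it is present in the
  # evaluation and its key appears in output_paths. Parquet rows pair the
  # slice key with the serialized proto; TFRecord stores the proto alone.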
  if metrics_key in evaluation and constants.METRICS_KEY in output_paths:
    metrics = (
        evaluation[metrics_key] | 'ConvertSliceMetricsToProto' >> beam.Map(
            convert_slice_metrics_to_proto,
            add_metrics_callbacks=add_metrics_callbacks))

    file_path_prefix = output_paths[constants.METRICS_KEY]
    if output_file_format == _PARQUET_FORMAT:
      _ = (
          metrics
          | 'ConvertToParquetColumns' >> beam.Map(convert_to_parquet_columns)
          | 'WriteMetricsToParquet' >> beam.io.WriteToParquet(
              file_path_prefix=file_path_prefix,
              schema=_SLICED_PARQUET_SCHEMA,
              file_name_suffix='.' + output_file_format))
    elif output_file_format == _TFRECORD_FORMAT:
      # output_file_format is guaranteed non-empty in this branch, so the
      # suffix is applied unconditionally and default sharding is used.
      _ = metrics | 'WriteMetricsToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=file_path_prefix,
          file_name_suffix='.' + output_file_format,
          coder=beam.coders.ProtoCoder(metrics_for_slice_pb2.MetricsForSlice))
    else:
      raise ValueError(f'Unsupported output file format: {output_file_format}.')

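  # Plots follow the same write pattern as metrics, serialized as
  # PlotsForSlice protos.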
  if plots_key in evaluation and constants.PLOTS_KEY in output_paths:
    plots = (
        evaluation[plots_key] | 'ConvertSlicePlotsToProto' >> beam.Map(
            convert_slice_plots_to_proto,
            add_metrics_callbacks=add_metrics_callbacks))

    file_path_prefix = output_paths[constants.PLOTS_KEY]
    if output_file_format == _PARQUET_FORMAT:
      _ = (
          plots
          | 'ConvertPlotsToParquetColumns' >> beam.Map(
              convert_to_parquet_columns)
          | 'WritePlotsToParquet' >> beam.io.WriteToParquet(
              file_path_prefix=file_path_prefix,
              schema=_SLICED_PARQUET_SCHEMA,
              file_name_suffix='.' + output_file_format))
    elif output_file_format == _TFRECORD_FORMAT:
      _ = plots | 'WritePlotsToTFRecord' >> beam.io.WriteToTFRecord(
          file_path_prefix=file_path_prefix,
          file_name_suffix='.' + output_file_format,
          coder=beam.coders.ProtoCoder(metrics_for_slice_pb2.PlotsForSlice))
    else:
      raise ValueError(f'Unsupported output file format: {output_file_format}.')

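  # Attributions likewise, serialized as AttributionsForSlice protos.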
  if (attributions_key in evaluation and
      constants.ATTRIBUTIONS_KEY in output_paths):
    attributions = (
        evaluation[attributions_key]
        | 'ConvertSliceAttributionsToProto' >> beam.Map(
            convert_slice_attributions_to_proto))

    file_path_prefix = output_paths[constants.ATTRIBUTIONS_KEY]
    if output_file_format == _PARQUET_FORMAT:
      _ = (
          attributions
          | 'ConvertAttributionsToParquetColumns' >> beam.Map(
              convert_to_parquet_columns)
          | 'WriteAttributionsToParquet' >> beam.io.WriteToParquet(
              file_path_prefix=file_path_prefix,
              schema=_SLICED_PARQUET_SCHEMA,
              file_name_suffix='.' + output_file_format))
    elif output_file_format == _TFRECORD_FORMAT:
      _ = (
          attributions
          | 'WriteAttributionsToTFRecord' >> beam.io.WriteToTFRecord(
              file_path_prefix=file_path_prefix,
              file_name_suffix='.' + output_file_format,
              coder=beam.coders.ProtoCoder(
                  metrics_for_slice_pb2.AttributionsForSlice)))
    else:
      raise ValueError(f'Unsupported output file format: {output_file_format}.')

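  # Unlike the sliced outputs above, validations are combined globally into
  # a single ValidationResult.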
  if (validations_key in evaluation and
      constants.VALIDATIONS_KEY in output_paths):
    validations = (
        evaluation[validations_key]
        | 'MergeValidationResults' >> beam.CombineGlobally(
            CombineValidations(eval_config, rubber_stamp=rubber_stamp)))

    file_path_prefix = output_paths[constants.VALIDATIONS_KEY]
    # Validations were merged into a single result above, so a single output
    # shard suffices; the empty shard_name_template forces one file.
    shard_name_template = ''
    if output_file_format == _PARQUET_FORMAT:
      _ = (
          validations
          | 'ConvertValidationsToParquetColumns' >> beam.Map(
              lambda v:  # pylint: disable=g-long-lambda
              {_SERIALIZED_VALUE_PARQUET_COLUMN_NAME: v.SerializeToString()})
          | 'WriteValidationsToParquet' >> beam.io.WriteToParquet(
              file_path_prefix=file_path_prefix,
              shard_name_template=shard_name_template,
              schema=_UNSLICED_PARQUET_SCHEMA,
              file_name_suffix='.' + output_file_format))
    elif output_file_format == _TFRECORD_FORMAT:
      _ = (
          validations
          | 'WriteValidationsToTFRecord' >> beam.io.WriteToTFRecord(
              file_path_prefix=file_path_prefix,
              shard_name_template=shard_name_template,
              file_name_suffix='.' + output_file_format,
              coder=beam.coders.ProtoCoder(
                  validation_result_pb2.ValidationResult)))
    else:
      raise ValueError(f'Unsupported output file format: {output_file_format}.')

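  # All outputs are written as side effects above; signal completion by
  # returning PDone on the shared pipeline.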
  return beam.pvalue.PDone(list(evaluation.values())[0].pipeline)
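

Usage sketch (an assumption-laden illustration, not taken from the source):
the constants.*_KEY values are assumed to double as the keys of the
evaluation dict, and make_evaluation() is a hypothetical stand-in for the
upstream evaluator stage; _TFRECORD_FORMAT and the writer itself come from
this module.

import apache_beam as beam

from tensorflow_model_analysis import constants
from tensorflow_model_analysis.proto import config_pb2

with beam.Pipeline() as pipeline:
  # Hypothetical upstream stage producing a dict of PCollections keyed by
  # result type (metrics, plots, attributions, validations).
  evaluation = make_evaluation(pipeline)
  _WriteMetricsPlotsAndValidations(
      evaluation=evaluation,
      output_paths={
          constants.METRICS_KEY: '/tmp/eval/metrics',
          constants.PLOTS_KEY: '/tmp/eval/plots',
      },
      eval_config=config_pb2.EvalConfig(),
      add_metrics_callbacks=[],
      metrics_key=constants.METRICS_KEY,
      plots_key=constants.PLOTS_KEY,
      attributions_key=constants.ATTRIBUTIONS_KEY,
      validations_key=constants.VALIDATIONS_KEY,
      output_file_format=_TFRECORD_FORMAT)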