def assertAnalyzeAndTransformResults()

in tensorflow_transform/beam/tft_unit.py [0:0]


  def assertAnalyzeAndTransformResults(self,
                                       input_data,
                                       input_metadata,
                                       preprocessing_fn,
                                       expected_data=None,
                                       expected_metadata=None,
                                       expected_vocab_file_contents=None,
                                       test_data=None,
                                       desired_batch_size=None,
                                       beam_pipeline=None,
                                       temp_dir=None,
                                       force_tf_compat_v1=False,
                                       output_record_batches=False):
    """Assert that input data and metadata are transformed as expected.

    This method runs `preprocessing_fn` over `input_data` in a Beam pipeline,
    writes the resulting transform_fn (and, if `expected_data` is given, the
    transformed data as tf.Examples) under `temp_dir`, and then asserts that
    the transformed data, transformed schema and requested vocabulary files
    match `expected_data`, `expected_metadata` and
    `expected_vocab_file_contents` respectively.

    Args:
      input_data: Input data formatted in one of two ways:
        * A sequence of dicts whose values are one of:
          strings, lists of strings, numeric types or a pair of those.
          Must have at least one key so that we can infer the batch size, or
        * A sequence of pa.RecordBatch.
      input_metadata: One of -
        * DatasetMetadata describing input_data if `input_data` are dicts.
        * TensorAdapterConfig otherwise.
      preprocessing_fn: A function taking a dict of tensors and returning
          a dict of tensors.
      expected_data: (optional) A dataset with the same type constraints as
          input_data, but representing the output after transformation.
          If supplied, transformed data is asserted to be equal.
      expected_metadata: (optional) DatasetMetadata describing the transformed
          data. If supplied, its schema is asserted to be equal to the
          transformed schema. Schema-level and per-feature annotations are
          cleared from the transformed schema before the comparison.
      expected_vocab_file_contents: (optional) A dictionary from vocab filenames
          to their expected content as a list of text lines or a list of tuples
          of frequency and text. Values should be the expected result of calling
          f.readlines() on the given asset files.
      test_data: (optional) If this is provided then instead of calling
          AnalyzeAndTransformDataset with input_data, this function will call
          AnalyzeDataset with input_data and TransformDataset with test_data.
          Note that this is the case even if input_data and test_data are equal.
          test_data should also conform to input_metadata.
      desired_batch_size: (optional) A batch size to batch elements by. If not
          provided, a batch size will be computed automatically.
      beam_pipeline: (optional) A Beam Pipeline to use in this test.
      temp_dir: If set, it is used as output directory, else a new unique
          directory is created.
      force_tf_compat_v1: A bool. If `True`, TFT's public APIs use Tensorflow
          in compat.v1 mode.
      output_record_batches: (optional) A bool. If `True`, `TransformDataset`
          and `AnalyzeAndTransformDataset` output `pyarrow.RecordBatch`es;
          otherwise, they output instance dicts.
    Raises:
      AssertionError: if the expected data does not match the results of
          transforming input_data according to preprocessing_fn, or
          (if provided) if the expected metadata or the expected vocabulary
          file contents do not match.
    """

    expected_vocab_file_contents = expected_vocab_file_contents or {}

    # Note: we don't separately test AnalyzeDataset and TransformDataset as
    # AnalyzeAndTransformDataset currently simply composes these two
    # transforms.  If in future versions of the code, the implementation
    # differs, we should also run AnalyzeDataset and TransformDataset composed.
    temp_dir = temp_dir or tempfile.mkdtemp(
        prefix=self._testMethodName, dir=self.get_temp_dir())
    with beam_pipeline or self._makeTestPipeline() as pipeline:
      with beam_impl.Context(
          temp_dir=temp_dir,
          desired_batch_size=desired_batch_size,
          force_tf_compat_v1=force_tf_compat_v1):
        # reshuffle=False skips the Reshuffle step beam.Create would otherwise
        # insert after the source.
        input_data = pipeline | 'CreateInput' >> beam.Create(input_data,
                                                             reshuffle=False)
        if test_data is None:
          (transformed_data, transformed_metadata), transform_fn = (
              (input_data, input_metadata)
              | beam_impl.AnalyzeAndTransformDataset(
                  preprocessing_fn,
                  output_record_batches=output_record_batches))
        else:
          # Analyze on input_data, but apply the resulting transform_fn to
          # test_data.
          transform_fn = ((input_data, input_metadata)
                          | beam_impl.AnalyzeDataset(preprocessing_fn))
          # Created the same way as input_data so both sources are consistent.
          test_data = pipeline | 'CreateTest' >> beam.Create(test_data,
                                                             reshuffle=False)
          transformed_data, transformed_metadata = (
              ((test_data, input_metadata), transform_fn)
              | beam_impl.TransformDataset(
                  output_record_batches=output_record_batches))

        # Write transform_fn so we can test its assets
        _ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir)

        transformed_data_path = os.path.join(temp_dir, 'transformed_data')
        if expected_data is not None:
          if isinstance(transformed_metadata,
                        beam_metadata_io.BeamDatasetMetadata):
            deferred_schema = (
                transformed_metadata.deferred_metadata
                | 'GetDeferredSchema' >> beam.Map(lambda m: m.schema))
          else:
            # Lift the already-known schema into a single-element PCollection
            # on *this* pipeline so it can feed the coder side input below.
            # (A PCollection from any other pipeline cannot be used here.)
            deferred_schema = (
                pipeline | 'CreateDeferredSchema' >> beam.Create(
                    [transformed_metadata.schema]))

          if output_record_batches:
            # Since we are using a deferred schema, obtain a pcollection
            # containing the data coder that will be created from it.
            transformed_data_coder_pcol = (
                deferred_schema | 'RecordBatchToExamplesEncoder' >> beam.Map(
                    example_coder.RecordBatchToExamplesEncoder))
            # Extract transformed RecordBatches and convert them to tf.Examples.
            # Each element is a tuple whose second item is ignored here.
            encode_ptransform = 'EncodeRecordBatches' >> beam.FlatMapTuple(
                lambda batch, _, data_coder: data_coder.encode(batch),
                data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol))
          else:
            # Since we are using a deferred schema, obtain a pcollection
            # containing the data coder that will be created from it.
            transformed_data_coder_pcol = (
                deferred_schema
                | 'ExampleProtoCoder' >> beam.Map(tft.coders.ExampleProtoCoder))
            encode_ptransform = 'EncodeExamples' >> beam.Map(
                lambda data, data_coder: data_coder.encode(data),
                data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol))

          # shard_name_template='' writes a single file at exactly
          # transformed_data_path, which is read back below.
          _ = (
              transformed_data
              | encode_ptransform
              | beam.io.tfrecordio.WriteToTFRecord(
                  transformed_data_path, shard_name_template=''))

    # TODO(ebreck) Log transformed_data somewhere.
    tf_transform_output = tft.TFTransformOutput(temp_dir)
    if expected_data is not None:
      examples = tf.compat.v1.python_io.tf_record_iterator(
          path=transformed_data_path)
      # Per-feature shapes from the transformed schema; features without a
      # fixed shape get [-1].
      shapes = {
          f.name:
          [s.size for s in f.shape.dim] if f.HasField('shape') else [-1]
          for f in tf_transform_output.transformed_metadata.schema.feature
      }
      transformed_data = [
          _format_example_as_numpy_dict(e, shapes) for e in examples
      ]
      self.assertDataCloseOrEqual(expected_data, transformed_data)

    if expected_metadata is not None:
      # Make a copy with no annotations.
      transformed_schema = schema_pb2.Schema()
      transformed_schema.CopyFrom(
          tf_transform_output.transformed_metadata.schema)
      transformed_schema.ClearField('annotation')
      for feature in transformed_schema.feature:
        feature.ClearField('annotation')

      # assertProtoEqual has a size limit on the length of the
      # serialized as text strings. Therefore, we first try to use
      # assertProtoEqual, if that fails we try to use assertEqual, if that fails
      # as well then we raise the exception from assertProtoEqual.
      try:
        compare.assertProtoEqual(self, expected_metadata.schema,
                                 transformed_schema)
      except AssertionError as compare_exception:
        try:
          self.assertEqual(expected_metadata.schema, transformed_schema)
        except AssertionError:
          raise compare_exception

    for filename, file_contents in expected_vocab_file_contents.items():
      full_filename = tf_transform_output.vocabulary_file_by_name(filename)
      self.AssertVocabularyContents(full_filename, file_contents)