def assertAnalyzerOutputs()

in tensorflow_transform/beam/tft_unit.py [0:0]


  def assertAnalyzerOutputs(self,
                            input_data,
                            input_metadata,
                            analyzer_fn,
                            expected_outputs,
                            test_data=None,
                            desired_batch_size=None,
                            beam_pipeline=None,
                            force_tf_compat_v1=False,
                            output_record_batches=False):
    """Assert that the analyzers called in `analyzer_fn` yield `expected_outputs`.

    `analyzer_fn` is wrapped in a preprocessing_fn that adds a batch dimension
    to each analyzer output and tiles it across the batch. The result is run
    through `assertAnalyzeAndTransformResults`, expecting every transformed
    instance to equal `expected_outputs`, with metadata built from the shapes
    and dtypes of `expected_outputs`.

    Args:
      input_data: Input data formatted in one of two ways:
        * A sequence of dicts whose values are one of:
          strings, lists of strings, numeric types or a pair of those.
          Must have at least one key so that we can infer the batch size, or
        * A sequence of pa.RecordBatch.
      input_metadata: One of -
        * DatasetMetadata describing input_data if `input_data` are dicts.
        * TensorAdapterConfig otherwise.
      analyzer_fn: A function taking a dict of tensors and returning a dict of
        tensors.  Unlike a preprocessing_fn, this should emit the results of a
        call to an analyzer, while a preprocessing_fn must typically add a batch
        dimension and broadcast across this batch dimension.
      expected_outputs: A dict whose keys are the same as those of the output of
        `analyzer_fn` and whose values expose `.shape` and `.dtype` (e.g. numpy
        ndarrays). Each value's shape is set as the static shape of the
        corresponding analyzer output, and its shape/dtype define the expected
        output metadata.
      test_data: (optional) If this is provided then instead of calling
        AnalyzeAndTransformDataset with input_data, this function will call
        AnalyzeDataset with input_data and TransformDataset with test_data.
        Must be provided if the input_data is empty. test_data should also
        conform to input_metadata. If omitted, the first element of
        `input_data` repeated three times is used.
      desired_batch_size: (Optional) A batch size to batch elements by. If not
        provided, a batch size will be computed automatically.
      beam_pipeline: (optional) A Beam Pipeline to use in this test.
      force_tf_compat_v1: A bool. If `True`, TFT's public APIs use
        Tensorflow in compat.v1 mode.
      output_record_batches: (optional) A bool. If `True`, `TransformDataset`
        and `AnalyzeAndTransformDataset` output `pyarrow.RecordBatch`es;
        otherwise, they output instance dicts.

    Raises:
      AssertionError: If the expected output does not match the results of
          the analyzer_fn, if the output keys differ from those of
          `expected_outputs`, or if `input_data` is empty and no non-empty
          `test_data` is given.
      ValueError: If an analyzer output's shape is incompatible with the shape
          of the corresponding expected output, or if the preprocessing_fn
          receives no input features to infer the batch size from.
    """

    def preprocessing_fn(inputs):
      """Broadcasts the outputs of `analyzer_fn` across the input batch."""
      # Get tensors representing the outputs of the analyzers.
      analyzer_outputs = analyzer_fn(inputs)

      # Check that keys of analyzer_outputs match expected_outputs.
      self.assertCountEqual(analyzer_outputs.keys(), expected_outputs.keys())

      # Get batch size from any input tensor. Fail with a clear message rather
      # than a bare StopIteration when there are no inputs at all.
      an_input = next(iter(inputs.values()), None)
      if an_input is None:
        raise ValueError(
            'assertAnalyzerOutputs requires at least one input feature in '
            'order to infer the batch size.')
      if isinstance(an_input, tf.RaggedTensor):
        batch_size = an_input.bounding_shape(axis=0)
      else:
        batch_size = tf.shape(input=an_input)[0]

      # Add a batch dimension and broadcast the analyzer outputs.
      result = {}
      for key, output_tensor in analyzer_outputs.items():
        # Pin the static shape to the expected one; the expected metadata
        # below is built from these same shapes.
        output_shape = list(expected_outputs[key].shape)
        try:
          output_tensor.set_shape(output_shape)
        except ValueError as e:
          # Chain the cause so the original shape error stays visible.
          raise ValueError('Error for key {}: {}'.format(key, str(e))) from e
        # Add a leading batch dimension of size 1 ...
        output_tensor = tf.expand_dims(output_tensor, 0)
        # ... and tile it `batch_size` times so every instance in the batch
        # carries the same analyzer output.
        result[key] = tf.tile(
            output_tensor, multiples=[batch_size] + [1] * len(output_shape))

      return result

    def _num_instances(dataset):
      """Returns how many instances (rows) `dataset` contains."""
      # Elements exposing `num_rows` (e.g. pa.RecordBatch) contribute all of
      # their rows; any other element (an instance dict) is one instance.
      return sum(getattr(element, 'num_rows', 1) for element in dataset)

    if input_data and not test_data:
      # Create test dataset by repeating the first element a number of times.
      num_test_instances = 3
      test_data = [input_data[0]] * num_test_instances
    else:
      # Ensure that the test dataset is specified and is not empty. Use the
      # unittest assertion (not `assert`) so the check survives `python -O`.
      self.assertTrue(
          test_data,
          'test_data must be provided and non-empty when input_data is empty.')

    # The preprocessing_fn broadcasts the analyzer outputs, so every
    # transformed instance should equal `expected_outputs`: expect one copy per
    # instance (row) of test_data.
    # NOTE(review): assumes assertAnalyzeAndTransformResults compares expected
    # data per instance even for RecordBatch test_data — confirm.
    expected_data = [expected_outputs] * _num_instances(test_data)
    expected_metadata = metadata_from_feature_spec({
        key: tf.io.FixedLenFeature(value.shape, tf.as_dtype(value.dtype))
        for key, value in expected_outputs.items()
    })

    self.assertAnalyzeAndTransformResults(
        input_data,
        input_metadata,
        preprocessing_fn,
        expected_data,
        expected_metadata,
        test_data=test_data,
        desired_batch_size=desired_batch_size,
        beam_pipeline=beam_pipeline,
        force_tf_compat_v1=force_tf_compat_v1,
        output_record_batches=output_record_batches)