in tensorflow_transform/beam/tft_unit.py [0:0]
def assertAnalyzerOutputs(self,
input_data,
input_metadata,
analyzer_fn,
expected_outputs,
test_data=None,
desired_batch_size=None,
beam_pipeline=None,
force_tf_compat_v1=False,
output_record_batches=False):
"""Assert that input data and metadata is transformed as expected.
This methods asserts transformed data and transformed metadata match
with expected_data and expected_metadata.
Args:
input_data: Input data formatted in one of two ways:
* A sequence of dicts whose values are one of:
strings, lists of strings, numeric types, or pairs of those.
Must have at least one key so that the batch size can be inferred, or
* A sequence of pa.RecordBatch.
input_metadata: One of:
* DatasetMetadata describing input_data if `input_data` are dicts.
* TensorAdapterConfig otherwise.
analyzer_fn: A function taking a dict of tensors and returning a dict of
tensors. Unlike a preprocessing_fn, it should emit the results of a
call to an analyzer directly; a preprocessing_fn, by contrast, must
typically add a batch dimension and broadcast across it (see the
example below).
expected_outputs: A dict whose keys are the same as those of the output of
`analyzer_fn` and whose values are convertible to ndarrays.
test_data: (Optional) If provided, then instead of calling
AnalyzeAndTransformDataset with input_data, this function will call
AnalyzeDataset with input_data and TransformDataset with test_data.
Must be provided if input_data is empty. test_data should also
conform to input_metadata.
desired_batch_size: (Optional) A batch size to batch elements by. If not
provided, a batch size will be computed automatically.
beam_pipeline: (Optional) A Beam Pipeline to use in this test.
force_tf_compat_v1: A bool. If `True`, TFT's public APIs use
TensorFlow in compat.v1 mode.
output_record_batches: (Optional) A bool. If `True`, `TransformDataset`
and `AnalyzeAndTransformDataset` output `pyarrow.RecordBatch`es;
otherwise, they output instance dicts.

Raises:
AssertionError: If the expected output does not match the results of
the analyzer_fn.
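
Example:
  A minimal illustrative sketch (the feature name 'x', the use of
  `tft.mean`, and the `tft`/`np` import aliases are assumptions of this
  example, not requirements of this method):

    def analyzer_fn(inputs):
      # tft.mean reduces over the whole dataset, emitting a scalar.
      return {'mean': tft.mean(inputs['x'])}

    input_data = [{'x': 1}, {'x': 2}, {'x': 3}]
    input_metadata = metadata_from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.int64)})
    self.assertAnalyzerOutputs(
        input_data, input_metadata, analyzer_fn,
        expected_outputs={'mean': np.float32(2.0)})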
"""
def preprocessing_fn(inputs):
"""A helper function for validating analyzer outputs."""
# Get tensors representing the outputs of the analyzers
analyzer_outputs = analyzer_fn(inputs)
# Check that the keys of analyzer_outputs match those of expected_outputs.
self.assertCountEqual(analyzer_outputs.keys(), expected_outputs.keys())
# Get batch size from any input tensor.
an_input = next(iter(inputs.values()))
if isinstance(an_input, tf.RaggedTensor):
batch_size = an_input.bounding_shape(axis=0)
else:
batch_size = tf.shape(input=an_input)[0]
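# For example, a dense input of shape [batch_size, 2] yields batch_size
# from its leading dimension; for a tf.RaggedTensor, the bounding shape's
# leading dimension is the number of rows in the batch.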
# Add a batch dimension and broadcast the analyzer outputs.
result = {}
for key, output_tensor in analyzer_outputs.items():
# Get the expected shape, and set it.
output_shape = list(expected_outputs[key].shape)
try:
output_tensor.set_shape(output_shape)
except ValueError as e:
raise ValueError(f'Error for key {key}: {e}') from e
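# Illustration (assumed value): if expected_outputs[key] has shape (3,),
# output_tensor is constrained to shape [3] above, expanded to [1, 3]
# below, and then tiled to [batch_size, 3].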
# Add a batch dimension
output_tensor = tf.expand_dims(output_tensor, 0)
# Broadcast along the batch dimension
result[key] = tf.tile(
output_tensor, multiples=[batch_size] + [1] * len(output_shape))
return result
if input_data and not test_data:
# Create test dataset by repeating the first instance a number of times.
num_test_instances = 3
test_data = [input_data[0]] * num_test_instances
expected_data = [expected_outputs] * num_test_instances
else:
# Ensure that the test dataset is specified and is not empty.
assert test_data
expected_data = [expected_outputs] * len(test_data)
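# At this point test_data and expected_data are parallel. E.g.
# (hypothetical values) input_data = [{'x': 1}, {'x': 2}] with no
# test_data gives test_data == [{'x': 1}] * 3 and
# expected_data == [expected_outputs] * 3.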
expected_metadata = metadata_from_feature_spec({
key: tf.io.FixedLenFeature(value.shape, tf.as_dtype(value.dtype))
for key, value in expected_outputs.items()
})
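# For instance (assumed value), an expected output of np.float32(2.0) has
# shape () and dtype float32, so it contributes
# tf.io.FixedLenFeature([], tf.float32) to the expected metadata.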
self.assertAnalyzeAndTransformResults(
input_data,
input_metadata,
preprocessing_fn,
expected_data,
expected_metadata,
test_data=test_data,
desired_batch_size=desired_batch_size,
beam_pipeline=beam_pipeline,
force_tf_compat_v1=force_tf_compat_v1,
output_record_batches=output_record_batches)