in tensorflow_transform/beam/tft_unit.py [0:0]
def assertAnalyzeAndTransformResults(self,
                                     input_data,
                                     input_metadata,
                                     preprocessing_fn,
                                     expected_data=None,
                                     expected_metadata=None,
                                     expected_vocab_file_contents=None,
                                     test_data=None,
                                     desired_batch_size=None,
                                     beam_pipeline=None,
                                     temp_dir=None,
                                     force_tf_compat_v1=False,
                                     output_record_batches=False):
"""Assert that input data and metadata is transformed as expected.
This methods asserts transformed data and transformed metadata match
with expected_data and expected_metadata.
Args:
input_data: Input data formatted in one of two ways:
* A sequence of dicts whose values are one of:
strings, lists of strings, numeric types or a pair of those.
Must have at least one key so that we can infer the batch size, or
* A sequence of pa.RecordBatch.
input_metadata: One of -
* DatasetMetadata describing input_data if `input_data` are dicts.
* TensorAdapterConfig otherwise.
preprocessing_fn: A function taking a dict of tensors and returning
a dict of tensors.
expected_data: (optional) A dataset with the same type constraints as
input_data, but representing the output after transformation.
If supplied, transformed data is asserted to be equal.
expected_metadata: (optional) DatasetMetadata describing the transformed
data. If supplied, transformed metadata is asserted to be equal.
expected_vocab_file_contents: (optional) A dictionary from vocab filenames
to their expected content as a list of text lines or a list of tuples
of frequency and text. Values should be the expected result of calling
f.readlines() on the given asset files.
test_data: (optional) If this is provided then instead of calling
AnalyzeAndTransformDataset with input_data, this function will call
AnalyzeDataset with input_data and TransformDataset with test_data.
Note that this is the case even if input_data and test_data are equal.
test_data should also conform to input_metadata.
desired_batch_size: (optional) A batch size to batch elements by. If not
provided, a batch size will be computed automatically.
beam_pipeline: (optional) A Beam Pipeline to use in this test.
temp_dir: If set, it is used as output directory, else a new unique
directory is created.
force_tf_compat_v1: A bool. If `True`, TFT's public APIs use Tensorflow
in compat.v1 mode.
output_record_batches: (optional) A bool. If `True`, `TransformDataset`
and `AnalyzeAndTransformDataset` output `pyarrow.RecordBatch`es;
otherwise, they output instance dicts.
Raises:
AssertionError: if the expected data does not match the results of
transforming input_data according to preprocessing_fn, or
(if provided) if the expected metadata does not match.
"""
  expected_vocab_file_contents = expected_vocab_file_contents or {}
  # Note: we don't separately test AnalyzeDataset and TransformDataset as
  # AnalyzeAndTransformDataset currently simply composes these two
  # transforms. If the implementation differs in future versions of the
  # code, we should also test the composition of AnalyzeDataset and
  # TransformDataset.
  temp_dir = temp_dir or tempfile.mkdtemp(
      prefix=self._testMethodName, dir=self.get_temp_dir())
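  # Run the whole Beam pipeline inside the context managers below so that it
  # has finished executing, and its outputs have been written to temp_dir,
  # before the assertions that follow read them back.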
  with beam_pipeline or self._makeTestPipeline() as pipeline:
    with beam_impl.Context(
        temp_dir=temp_dir,
        desired_batch_size=desired_batch_size,
        force_tf_compat_v1=force_tf_compat_v1):
      input_data = pipeline | 'CreateInput' >> beam.Create(input_data,
                                                           reshuffle=False)
      if test_data is None:
        (transformed_data, transformed_metadata), transform_fn = (
            (input_data, input_metadata)
            | beam_impl.AnalyzeAndTransformDataset(
                preprocessing_fn,
                output_record_batches=output_record_batches))
      else:
        transform_fn = ((input_data, input_metadata)
                        | beam_impl.AnalyzeDataset(preprocessing_fn))
        test_data = pipeline | 'CreateTest' >> beam.Create(test_data)
        transformed_data, transformed_metadata = (
            ((test_data, input_metadata), transform_fn)
            | beam_impl.TransformDataset(
                output_record_batches=output_record_batches))

      # Write transform_fn so we can test its assets.
      _ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir)

      transformed_data_path = os.path.join(temp_dir, 'transformed_data')
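      # If expected data was supplied, encode the transformed data as
      # serialized tf.Example protos and write it to a TFRecord file so that
      # it can be read back and compared once the pipeline has run.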
      if expected_data is not None:
        if isinstance(transformed_metadata,
                      beam_metadata_io.BeamDatasetMetadata):
          deferred_schema = (
              transformed_metadata.deferred_metadata
              | 'GetDeferredSchema' >> beam.Map(lambda m: m.schema))
        else:
          deferred_schema = (
              pipeline | 'CreateDeferredSchema' >> beam.Create(
                  [transformed_metadata.schema]))
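        # Build an encoder for the transformed data from the deferred schema:
        # RecordBatch outputs use example_coder.RecordBatchToExamplesEncoder,
        # while instance-dict outputs use tft.coders.ExampleProtoCoder.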
        if output_record_batches:
          # Since we are using a deferred schema, obtain a pcollection
          # containing the data coder that will be created from it.
          transformed_data_coder_pcol = (
              deferred_schema | 'RecordBatchToExamplesEncoder' >> beam.Map(
                  example_coder.RecordBatchToExamplesEncoder))
          # Extract transformed RecordBatches and convert them to tf.Examples.
          encode_ptransform = 'EncodeRecordBatches' >> beam.FlatMapTuple(
              lambda batch, _, data_coder: data_coder.encode(batch),
              data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol))
        else:
          # Since we are using a deferred schema, obtain a pcollection
          # containing the data coder that will be created from it.
          transformed_data_coder_pcol = (
              deferred_schema
              | 'ExampleProtoCoder' >> beam.Map(tft.coders.ExampleProtoCoder))
          encode_ptransform = 'EncodeExamples' >> beam.Map(
              lambda data, data_coder: data_coder.encode(data),
              data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol))
        _ = (
            transformed_data
            | encode_ptransform
            | beam.io.tfrecordio.WriteToTFRecord(
                transformed_data_path, shard_name_template=''))
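
  # At this point the `with` blocks above have exited, so the pipeline has
  # run to completion and its outputs in temp_dir can be read directly.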
  # TODO(ebreck) Log transformed_data somewhere.
  tf_transform_output = tft.TFTransformOutput(temp_dir)
  if expected_data is not None:
    examples = tf.compat.v1.python_io.tf_record_iterator(
        path=transformed_data_path)
    shapes = {
        f.name:
            [s.size for s in f.shape.dim] if f.HasField('shape') else [-1]
        for f in tf_transform_output.transformed_metadata.schema.feature
    }
    transformed_data = [
        _format_example_as_numpy_dict(e, shapes) for e in examples
    ]
    self.assertDataCloseOrEqual(expected_data, transformed_data)
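
  # Compare schemas with schema-level and feature-level annotations cleared,
  # so only the structural parts of the transformed schema are checked
  # against expected_metadata.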
  if expected_metadata:
    # Make a copy with no annotations.
    transformed_schema = schema_pb2.Schema()
    transformed_schema.CopyFrom(
        tf_transform_output.transformed_metadata.schema)
    transformed_schema.ClearField('annotation')
    for feature in transformed_schema.feature:
      feature.ClearField('annotation')

    # assertProtoEqual has a size limit on the length of protos serialized as
    # text strings. Therefore, we first try assertProtoEqual; if that fails
    # we fall back to assertEqual, and if that also fails we re-raise the
    # exception from assertProtoEqual.
    try:
      compare.assertProtoEqual(self, expected_metadata.schema,
                               transformed_schema)
    except AssertionError as compare_exception:
      try:
        self.assertEqual(expected_metadata.schema, transformed_schema)
      except AssertionError:
        raise compare_exception

  for filename, file_contents in expected_vocab_file_contents.items():
    full_filename = tf_transform_output.vocabulary_file_by_name(filename)
    self.AssertVocabularyContents(full_filename, file_contents)
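

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of tft_unit.py): roughly how a test that
# subclasses tft_unit.TransformTestCase might call the assertion above. The
# preprocessing_fn, feature names, and expected values are assumptions chosen
# for this example, and tft_unit.metadata_from_feature_spec is assumed to be
# available for building the input DatasetMetadata from a feature spec.
# ---------------------------------------------------------------------------
import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform.beam import tft_unit


class ScaleTo01Test(tft_unit.TransformTestCase):

  def test_scale_to_0_1(self):

    def preprocessing_fn(inputs):
      # Rescale 'x' to the [0, 1] range using a TFT analyzer-backed mapper.
      return {'x_scaled': tft.scale_to_0_1(inputs['x'])}

    input_data = [{'x': 1.0}, {'x': 2.0}, {'x': 3.0}]
    input_metadata = tft_unit.metadata_from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.float32)})
    # min(x) == 1 and max(x) == 3, so (x - 1) / 2 gives 0.0, 0.5, 1.0.
    expected_data = [{'x_scaled': 0.0}, {'x_scaled': 0.5}, {'x_scaled': 1.0}]
    self.assertAnalyzeAndTransformResults(input_data, input_metadata,
                                          preprocessing_fn, expected_data)


if __name__ == '__main__':
  tf.test.main()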