in tensorflow_transform/beam/impl.py
def expand(self, dataset_and_transform_fn):
  """Transforms the dataset using the transform_fn.

  Args:
    dataset_and_transform_fn: A tuple of a (dataset, metadata) pair and a
      (transform_fn, output metadata) pair.

  Returns:
    A (dataset, metadata) pair, transformed according to the transform_fn.
  """
  (input_values, input_metadata), (transform_fn, output_metadata) = (
      dataset_and_transform_fn)
  if isinstance(input_metadata, dataset_metadata.DatasetMetadata):
    if Context.get_passthrough_keys():
      raise ValueError('passthrough_keys is set to {} but it is not '
                       'supported with instance dicts + DatasetMetadata '
                       'input. Follow the guide to switch to the TFXIO '
                       'format.'.format(Context.get_passthrough_keys()))
    tf.compat.v1.logging.warning(
        'You are passing instance dicts and DatasetMetadata to TFT which '
        'will not provide optimal performance. Consider following the TFT '
        'guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).')
    to_tfxio_ptransform = _InstanceDictInputToTFXIOInput(
        input_metadata.schema, Context.get_desired_batch_size())
    input_tensor_adapter_config = to_tfxio_ptransform.tensor_adapter_config()
    input_values |= 'InstanceDictToRecordBatch' >> to_tfxio_ptransform
  else:
    input_tensor_adapter_config = input_metadata
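  # Both branches converge here: `input_values` holds Arrow RecordBatches and
  # `input_tensor_adapter_config` describes how to decode them into tensors
  # (in the TFXIO case the caller supplies it directly as `input_metadata`).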
  # If exclude_outputs is set, update the output metadata.
  if self._exclude_outputs is not None:
    if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata):
      new_metadata = _remove_columns_from_metadata(
          output_metadata.dataset_metadata, self._exclude_outputs)
      new_deferred_metadata = (
          output_metadata.deferred_metadata
          | 'RemoveColumns' >> beam.Map(_remove_columns_from_metadata,
                                        self._exclude_outputs))
      output_metadata = beam_metadata_io.BeamDatasetMetadata(
          new_metadata, new_deferred_metadata, output_metadata.asset_map)
    else:
      output_metadata = _remove_columns_from_metadata(
          output_metadata, self._exclude_outputs)
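  # When the metadata is deferred (BeamDatasetMetadata), the columns can only
  # be dropped inside the pipeline, hence the beam.Map above; the plain
  # DatasetMetadata case is handled with a direct in-process call.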
  if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata):
    deferred_schema = (
        output_metadata.deferred_metadata
        | 'GetDeferredSchema' >> beam.Map(lambda m: m.schema))
  else:
    deferred_schema = (
        self.pipeline
        | 'CreateDeferredSchema' >> beam.Create([output_metadata.schema]))
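  # Either way, `deferred_schema` is a singleton PCollection holding the
  # output schema; it is consumed below through beam.pvalue.AsSingleton side
  # inputs.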
  tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
      type(self.pipeline.runner))
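  # Presumably a per-runner TensorFlow session config; dict.get returns None
  # for runner types without a dedicated entry.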
  output_batches = (
      input_values
      | 'Transform' >> beam.ParDo(
          _RunMetaGraphDoFn(
              tf_config,
              input_tensor_adapter_config=input_tensor_adapter_config,
              use_tf_compat_v1=self._use_tf_compat_v1,
              shared_graph_state_handle=shared.Shared(),
              passthrough_keys=Context.get_passthrough_keys(),
              exclude_outputs=self._exclude_outputs,
              convert_passthrough_data=not self._output_record_batches),
          saved_model_dir=beam.pvalue.AsSingleton(transform_fn)))
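  # `transform_fn` is a PCollection holding the saved model directory, so it
  # is handed to the DoFn as a singleton side input; shared.Shared() appears
  # to let DoFn instances on the same worker reuse one loaded graph state
  # instead of reloading the model per instance.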
  if self._output_record_batches:
    # Since we are using a deferred schema, obtain a PCollection containing
    # the converter that will be created from it.
    converter_pcol = (
        deferred_schema | 'MakeTensorToArrowConverter' >> beam.Map(
            impl_helper.make_tensor_to_arrow_converter))
    output_data = (
        output_batches | 'ConvertToRecordBatch' >> beam.Map(
            _convert_to_record_batch,
            schema=beam.pvalue.AsSingleton(deferred_schema),
            converter=beam.pvalue.AsSingleton(converter_pcol),
            passthrough_keys=Context.get_passthrough_keys(),
            input_metadata=input_metadata))
  else:
    output_data = (
        output_batches | 'ConvertAndUnbatchToInstanceDicts' >> beam.FlatMap(
            _convert_and_unbatch_to_instance_dicts,
            schema=beam.pvalue.AsSingleton(deferred_schema),
            passthrough_keys=Context.get_passthrough_keys()))
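  # The two branches mirror the accepted input formats: batched Arrow
  # RecordBatches when _output_record_batches is set, otherwise unbatched
  # instance dicts.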
  _clear_shared_state_after_barrier(self.pipeline, output_data)
  return (output_data, output_metadata)