in tensorflow_transform/beam/impl.py
def expand(self, dataset):
"""Analyze the dataset.
Args:
dataset: A dataset.
Returns:
A TransformFn containing the deferred transform function.
Raises:
ValueError: If preprocessing_fn has no outputs.
"""
(flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
input_metadata) = dataset
input_values_pcoll_dict = input_values_pcoll_dict or dict()
if isinstance(input_metadata, dataset_metadata.DatasetMetadata):
if Context.get_passthrough_keys():
raise ValueError('passthrough_keys is set to {} but it is not supported '
'with instance dicts + DatasetMetadata input. Follow '
'the guide to switch to the TFXIO format.'.format(
Context.get_passthrough_keys()))
tf.compat.v1.logging.warning(
'You are passing instance dicts and DatasetMetadata to TFT which '
'will not provide optimal performance. Consider following the TFT '
'guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).')
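# For reference, an instance dict represents one example per Python dict,
# e.g. (illustrative) {'age': 29, 'language': b'en'}, whereas a TFXIO
# source yields pyarrow.RecordBatch objects holding many examples in
# columnar form, which is what makes it the better-performing path.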
to_tfxio_ptransform = _InstanceDictInputToTFXIOInput(
input_metadata.schema, Context.get_desired_batch_size())
input_tensor_adapter_config = to_tfxio_ptransform.tensor_adapter_config()
if flattened_pcoll is not None:
flattened_pcoll |= 'InstanceDictToRecordBatch' >> to_tfxio_ptransform
for key in input_values_pcoll_dict.keys():
if input_values_pcoll_dict[key] is not None:
input_values_pcoll_dict[key] |= (
'InstanceDictToRecordBatch[{}]'.format(key) >>
to_tfxio_ptransform)
else:
input_tensor_adapter_config = input_metadata
specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs()
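# `specs` maps each feature name to a tf.TypeSpec describing a batch of
# that feature. A sketch, assuming two dense features (hypothetical names):
#   {'x': tf.TensorSpec([None], tf.float32),
#    'label': tf.TensorSpec([None], tf.string)}
# Sparse or ragged features would appear as tf.SparseTensorSpec /
# tf.RaggedTensorSpec instead.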
if not specs:
raise ValueError('The input metadata is empty.')
base_temp_dir = Context.create_base_temp_dir()
# TODO(b/149997088): Do not pass base_temp_dir here as this graph does not
# need to be serialized to SavedModel.
graph, structured_inputs, structured_outputs = (
impl_helper.trace_preprocessing_function(self._preprocessing_fn, specs,
self._use_tf_compat_v1,
base_temp_dir))
# At this point we check that the preprocessing_fn has at least one
# output: if the output of preprocessing_fn were allowed to be empty, there
# would be no tensors from which to determine how many instances to
# "unbatch" each output batch into.
if not structured_outputs:
raise ValueError('The preprocessing function returned an empty dict')
if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
raise ValueError(
'The preprocessing function contained trainable variables '
'{}'.format(
graph.get_collection_ref(
tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))
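# self.pipeline can be None when this PTransform was applied to input
# PCollections rather than to the pipeline itself; in that case recover the
# pipeline from the flattened PCollection or, failing that, from the first
# non-None per-dataset PCollection.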
pipeline = self.pipeline or (flattened_pcoll or next(
v for v in input_values_pcoll_dict.values() if v is not None)).pipeline
# Add a stage that inspects graph collections for API use counts and logs
# them as a beam metric.
_ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI(
graph, Context._get_force_tf_compat_v1(), self._use_tf_compat_v1)) # pylint: disable=protected-access
asset_map = annotators.get_asset_annotations(graph)
# TF Hub can error when unapproved collections are present, so we
# explicitly clear the asset annotations out of the graph.
annotators.clear_asset_annotations(graph)
analyzers_fingerprint = graph_tools.get_analyzers_fingerprint(
graph, structured_inputs) if not self._use_tf_compat_v1 else None
tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
type(pipeline.runner))
extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs(
base_temp_dir=base_temp_dir,
tf_config=tf_config,
pipeline=pipeline,
flat_pcollection=flattened_pcoll,
pcollection_dict=input_values_pcoll_dict,
graph=graph,
input_signature=structured_inputs,
input_specs=specs,
input_tensor_adapter_config=input_tensor_adapter_config,
use_tf_compat_v1=self._use_tf_compat_v1,
cache_pcoll_dict=dataset_cache_dict,
preprocessing_fn=self._preprocessing_fn,
analyzers_fingerprint=analyzers_fingerprint)
transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
graph,
structured_inputs,
structured_outputs,
input_values_pcoll_dict.keys(),
cache_dict=dataset_cache_dict)
traverser = nodes.Traverser(
beam_common.ConstructBeamPipelineVisitor(extra_args))
transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)
if cache_value_nodes is not None:
output_cache_pcoll_dict = {}
for (dataset_key, cache_key), value_node in cache_value_nodes.items():
if dataset_key not in output_cache_pcoll_dict:
output_cache_pcoll_dict[dataset_key] = {}
output_cache_pcoll_dict[dataset_key][cache_key] = (
traverser.visit_value_node(value_node))
else:
output_cache_pcoll_dict = None
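# Shape of the cache output, mirroring the loop above (illustrative keys):
#   output_cache_pcoll_dict = {dataset_key: {cache_entry_key: pcoll}}
# Callers can write these PCollections out so that later runs can skip
# re-analyzing unchanged datasets.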
# Infer metadata. We take the inferred metadata and apply overrides that
# refer to values of tensors in the graph. The override tensors must be
# "constant" in that they don't depend on input data, though they may
# depend on analyzer outputs; this is what allows metadata to be derived
# from analysis results. _infer_metadata_from_saved_model will use the
# analyzer outputs stored in `transform_fn` to compute the metadata in a
# deferred manner, once the analyzer outputs are known.
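# For example (illustrative), the output of tft.compute_and_apply_vocabulary
# carries an integer domain whose size comes from the vocabulary analyzer,
# a value that is only known once analysis has run; the override mechanism
# lets the schema record it anyway.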
if self._use_tf_compat_v1:
schema = schema_inference.infer_feature_schema(structured_outputs, graph)
else:
# Use metadata_fn here as func_graph outputs may be wrapped in an identity
# op and hence may not return the same tensors that were annotated.
tf_graph_context = graph_context.TFGraphContext(
module_to_export=tf.Module(),
temp_dir=base_temp_dir,
evaluated_replacements={})
concrete_metadata_fn = schema_inference.get_traced_metadata_fn(
preprocessing_fn=self._preprocessing_fn,
structured_inputs=structured_inputs,
tf_graph_context=tf_graph_context,
evaluate_schema_overrides=False)
schema = schema_inference.infer_feature_schema_v2(
structured_outputs,
concrete_metadata_fn,
evaluate_schema_overrides=False)
deferred_metadata = (
transform_fn_pcoll
| 'ComputeDeferredMetadata[compat_v1={}]'.format(self._use_tf_compat_v1)
>> beam.Map(_infer_metadata_from_saved_model, self._use_tf_compat_v1))
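# deferred_metadata is a PCollection holding a single DatasetMetadata in
# which the "constant" overrides described above have been resolved from
# the analyzer outputs stored in the SavedModel.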
full_metadata = beam_metadata_io.BeamDatasetMetadata(
dataset_metadata.DatasetMetadata(schema=schema), deferred_metadata,
asset_map)
_clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)
return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict
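# A minimal usage sketch via the public API (assuming standard tft_beam
# imports; `raw_data`, `raw_metadata` and `preprocessing_fn` are
# placeholders supplied by the caller):
#   with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
#     transform_fn = ((raw_data, raw_metadata)
#                     | 'Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn))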