def expand()

in tensorflow_transform/beam/impl.py


  def expand(self, dataset):
    """Analyze the dataset.

    Args:
      dataset: A tuple of (flattened_pcoll, input_values_pcoll_dict,
        dataset_cache_dict, input_metadata) holding the input data, any
        analyzer cache, and the input metadata.

    Returns:
      A tuple of (TransformFn, output cache PCollection dict). The TransformFn
      contains the deferred transform function and its metadata; the cache
      dict is None when no cache was produced.

    Raises:
      ValueError: If the input metadata is empty, the preprocessing_fn has no
        outputs, it contains trainable variables, or passthrough keys are used
        with instance-dict input.
    """
    (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict,
     input_metadata) = dataset
    input_values_pcoll_dict = input_values_pcoll_dict or {}

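    # The input arrives either as instance dicts plus a DatasetMetadata
    # (legacy path, converted to RecordBatches below) or already in the TFXIO
    # format, in which case input_metadata is a tensor adapter config.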
    if isinstance(input_metadata, dataset_metadata.DatasetMetadata):
      if Context.get_passthrough_keys():
        raise ValueError('passthrough_keys is set to {} but it is not supported '
                         'with instance dicts + DatasetMetadata input. Follow '
                         'the guide to switch to the TFXIO format.'.format(
                             Context.get_passthrough_keys()))
      tf.compat.v1.logging.warning(
          'You are passing instance dicts and DatasetMetadata to TFT which '
          'will not provide optimal performance. Consider following the TFT '
          'guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).')
      to_tfxio_ptransform = _InstanceDictInputToTFXIOInput(
          input_metadata.schema, Context.get_desired_batch_size())
      input_tensor_adapter_config = to_tfxio_ptransform.tensor_adapter_config()
      if flattened_pcoll is not None:
        flattened_pcoll |= 'InstanceDictToRecordBatch' >> to_tfxio_ptransform
      for key in input_values_pcoll_dict.keys():
        if input_values_pcoll_dict[key] is not None:
          input_values_pcoll_dict[key] |= (
              'InstanceDictToRecordBatch[{}]'.format(key) >>
              to_tfxio_ptransform)
    else:
      input_tensor_adapter_config = input_metadata

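    # Both input paths converge here: the TensorAdapter recovers the original
    # tensor TypeSpecs with which the preprocessing_fn will be traced.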
    specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs()

    if not specs:
      raise ValueError('The input metadata is empty.')

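    # Trace the preprocessing_fn once, without data, to obtain its graph and
    # structured input/output tensors; analyzer nodes in this graph are later
    # replaced by deferred Beam computations.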
    base_temp_dir = Context.create_base_temp_dir()
    # TODO(b/149997088): Do not pass base_temp_dir here as this graph does not
    # need to be serialized to SavedModel.
    graph, structured_inputs, structured_outputs = (
        impl_helper.trace_preprocessing_function(self._preprocessing_fn, specs,
                                                 self._use_tf_compat_v1,
                                                 base_temp_dir))

    # At this point we check that the preprocessing_fn has at least one
    # output. If the output of preprocessing_fn were allowed to be empty, we
    # would have no way to determine how many instances to "unbatch" the
    # output into.
    if not structured_outputs:
      raise ValueError('The preprocessing function returned an empty dict')

    if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES):
      raise ValueError(
          'The preprocessing function contained trainable variables '
          '{}'.format(
              graph.get_collection_ref(
                  tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)))

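    # Resolve the Beam pipeline: prefer an explicitly provided pipeline,
    # otherwise take it from whichever input PCollection is available.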
    pipeline = self.pipeline or (flattened_pcoll or next(
        v for v in input_values_pcoll_dict.values() if v is not None)).pipeline

    # Add a stage that inspects graph collections for API use counts and logs
    # them as a beam metric.
    _ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI(
        graph, Context._get_force_tf_compat_v1(), self._use_tf_compat_v1))  # pylint: disable=protected-access

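    # Record asset annotations so they can be attached to the output metadata
    # (see BeamDatasetMetadata below).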
    asset_map = annotators.get_asset_annotations(graph)
    # TF.HUB can error when unapproved collections are present, so we
    # explicitly clear the asset annotation collections from the graph.
    annotators.clear_asset_annotations(graph)

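    # In the native TF2 path, fingerprint the analyzer subgraphs (used by the
    # analyzer cache machinery); the tf.compat.v1 path does not need this.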
    analyzers_fingerprint = graph_tools.get_analyzers_fingerprint(
        graph, structured_inputs) if not self._use_tf_compat_v1 else None

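    # Pick a TensorFlow config suited to the Beam runner in use, then bundle
    # everything the Beam pipeline construction visitor needs.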
    tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(
        type(pipeline.runner))
    extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs(
        base_temp_dir=base_temp_dir,
        tf_config=tf_config,
        pipeline=pipeline,
        flat_pcollection=flattened_pcoll,
        pcollection_dict=input_values_pcoll_dict,
        graph=graph,
        input_signature=structured_inputs,
        input_specs=specs,
        input_tensor_adapter_config=input_tensor_adapter_config,
        use_tf_compat_v1=self._use_tf_compat_v1,
        cache_pcoll_dict=dataset_cache_dict,
        preprocessing_fn=self._preprocessing_fn,
        analyzers_fingerprint=analyzers_fingerprint)

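    # Build the deferred analysis graph and walk it with the Beam visitor,
    # turning each node into PTransforms; the traversal yields a PCollection
    # holding the transform function.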
    transform_fn_future, cache_value_nodes = analysis_graph_builder.build(
        graph,
        structured_inputs,
        structured_outputs,
        input_values_pcoll_dict.keys(),
        cache_dict=dataset_cache_dict)
    traverser = nodes.Traverser(
        beam_common.ConstructBeamPipelineVisitor(extra_args))
    transform_fn_pcoll = traverser.visit_value_node(transform_fn_future)

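    # Surface any cache entries produced during analysis so callers can write
    # them out and reuse them in subsequent runs.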
    if cache_value_nodes is not None:
      output_cache_pcoll_dict = {}
      for (dataset_key, cache_key), value_node in cache_value_nodes.items():
        if dataset_key not in output_cache_pcoll_dict:
          output_cache_pcoll_dict[dataset_key] = {}
        output_cache_pcoll_dict[dataset_key][cache_key] = (
            traverser.visit_value_node(value_node))
    else:
      output_cache_pcoll_dict = None

    # Infer metadata.  We take the inferred metadata and apply overrides that
    # refer to values of tensors in the graph.  The override tensors must
    # be "constant" in that they don't depend on input data.  The tensors can
    # depend on analyzer outputs though.  This allows us to set metadata that
    # depends on analyzer outputs. _infer_metadata_from_saved_model will use the
    # analyzer outputs stored in `transform_fn` to compute the metadata in a
    # deferred manner, once the analyzer outputs are known.
    if self._use_tf_compat_v1:
      schema = schema_inference.infer_feature_schema(structured_outputs, graph)
    else:
      # Use metadata_fn here as func_graph outputs may be wrapped in an
      # identity op and hence may not be the same tensors that were annotated.
      tf_graph_context = graph_context.TFGraphContext(
          module_to_export=tf.Module(),
          temp_dir=base_temp_dir,
          evaluated_replacements={})
      concrete_metadata_fn = schema_inference.get_traced_metadata_fn(
          preprocessing_fn=self._preprocessing_fn,
          structured_inputs=structured_inputs,
          tf_graph_context=tf_graph_context,
          evaluate_schema_overrides=False)
      schema = schema_inference.infer_feature_schema_v2(
          structured_outputs,
          concrete_metadata_fn,
          evaluate_schema_overrides=False)
    deferred_metadata = (
        transform_fn_pcoll
        | 'ComputeDeferredMetadata[compat_v1={}]'.format(self._use_tf_compat_v1)
        >> beam.Map(_infer_metadata_from_saved_model, self._use_tf_compat_v1))

    full_metadata = beam_metadata_io.BeamDatasetMetadata(
        dataset_metadata.DatasetMetadata(schema=schema), deferred_metadata,
        asset_map)

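    # Add a barrier so shared model state can be released once the
    # transform_fn has been computed.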
    _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll)

    return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict
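
A minimal usage sketch for context, assuming the public tft_beam.AnalyzeDataset
wrapper (whose expand() ultimately runs the logic above); the preprocessing_fn,
schema, and data below are illustrative, not taken from the source:

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils


def preprocessing_fn(inputs):
  # scale_to_z_score introduces analyzers (mean, variance) that expand() turns
  # into deferred Beam computations.
  return {'x_scaled': tft.scale_to_z_score(inputs['x'])}


raw_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.float32)}))

with beam.Pipeline() as pipeline:
  with tft_beam.Context(temp_dir='/tmp/tft_tmp'):
    raw_data = pipeline | beam.Create([{'x': 1.0}, {'x': 2.0}, {'x': 3.0}])
    # Yields a deferred TransformFn: a (transform graph, metadata) pair.
    transform_fn = (
        (raw_data, raw_metadata)
        | 'Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn))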