def expand()

in tensorflow_transform/beam/impl.py [0:0]


  def expand(self, dataset_and_transform_fn):
    """Applies a previously computed transform_fn to a dataset.

    Args:
      dataset_and_transform_fn: A nested pair
        `((input_values, input_metadata), (transform_fn, output_metadata))`.
        If `input_metadata` is a `dataset_metadata.DatasetMetadata`, the
        `input_values` are treated as instance dicts and converted to record
        batches first; otherwise `input_metadata` is used directly as the
        input tensor adapter config. `transform_fn` is consumed as a
        singleton side input (the `saved_model_dir` of the transform step).
        `output_metadata` is either a `beam_metadata_io.BeamDatasetMetadata`
        or a metadata object exposing `.schema`.

    Returns:
      A `(output_data, output_metadata)` tuple. `output_data` contains record
      batches when `self._output_record_batches` is truthy and instance dicts
      otherwise. When `self._exclude_outputs` is not None, those columns are
      removed from `output_metadata`.

    Raises:
      ValueError: If passthrough keys are set on the `Context` while the
        input is instance dicts described by a `DatasetMetadata`.
    """
    dataset, transform_fn_and_metadata = dataset_and_transform_fn
    input_values, input_metadata = dataset
    transform_fn, output_metadata = transform_fn_and_metadata

    passthrough_keys = Context.get_passthrough_keys()
    exclude_outputs = self._exclude_outputs
    emit_record_batches = self._output_record_batches

    is_instance_dict_input = isinstance(input_metadata,
                                        dataset_metadata.DatasetMetadata)
    if not is_instance_dict_input:
      input_tensor_adapter_config = input_metadata
    else:
      # Passthrough keys cannot be honoured on the legacy instance-dict path.
      if passthrough_keys:
        raise ValueError('passthrough_keys is set to {} but it is not '
                         'supported with instance dicts + DatasetMetadata '
                         'input. Follow the guide to switch to the TFXIO '
                         'format.'.format(passthrough_keys))
      tf.compat.v1.logging.warning(
          'You are passing instance dicts and DatasetMetadata to TFT which '
          'will not provide optimal performance. Consider following the TFT '
          'guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).')
      to_record_batches = _InstanceDictInputToTFXIOInput(
          input_metadata.schema, Context.get_desired_batch_size())
      input_tensor_adapter_config = to_record_batches.tensor_adapter_config()
      input_values = (
          input_values | 'InstanceDictToRecordBatch' >> to_record_batches)

    # Drop excluded columns from both the static and the deferred metadata.
    if exclude_outputs is not None:
      if not isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata):
        output_metadata = _remove_columns_from_metadata(
            output_metadata, exclude_outputs)
      else:
        pruned_metadata = _remove_columns_from_metadata(
            output_metadata.dataset_metadata, exclude_outputs)
        # NOTE(review): 'RemoveColumms' looks like a typo. It is kept
        # verbatim because changing a step label renames the pipeline stage.
        pruned_deferred_metadata = (
            output_metadata.deferred_metadata
            | 'RemoveColumms' >> beam.Map(_remove_columns_from_metadata,
                                          exclude_outputs))
        output_metadata = beam_metadata_io.BeamDatasetMetadata(
            pruned_metadata, pruned_deferred_metadata,
            output_metadata.asset_map)

    # The output schema is always carried as a PCollection so the conversion
    # steps below can take it as a singleton side input.
    if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata):
      schema_pcoll = (
          output_metadata.deferred_metadata
          | 'GetDeferredSchema' >> beam.Map(lambda metadata: metadata.schema))
    else:
      schema_pcoll = (
          self.pipeline
          | 'CreateDeferredSchema' >> beam.Create([output_metadata.schema]))

    runner_type = type(self.pipeline.runner)
    tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get(runner_type)
    run_meta_graph_fn = _RunMetaGraphDoFn(
        tf_config,
        input_tensor_adapter_config=input_tensor_adapter_config,
        use_tf_compat_v1=self._use_tf_compat_v1,
        shared_graph_state_handle=shared.Shared(),
        passthrough_keys=passthrough_keys,
        exclude_outputs=exclude_outputs,
        convert_passthrough_data=not emit_record_batches)
    transformed_batches = input_values | 'Transform' >> beam.ParDo(
        run_meta_graph_fn,
        saved_model_dir=beam.pvalue.AsSingleton(transform_fn))

    if not emit_record_batches:
      output_data = (
          transformed_batches
          | 'ConvertAndUnbatchToInstanceDicts' >> beam.FlatMap(
              _convert_and_unbatch_to_instance_dicts,
              schema=beam.pvalue.AsSingleton(schema_pcoll),
              passthrough_keys=passthrough_keys))
    else:
      # The schema lives in a PCollection, so the converter derived from it
      # must be built inside the pipeline and passed in as a side input too.
      converter_pcoll = (
          schema_pcoll
          | 'MakeTensorToArrowConverter' >> beam.Map(
              impl_helper.make_tensor_to_arrow_converter))
      output_data = (
          transformed_batches
          | 'ConvertToRecordBatch' >> beam.Map(
              _convert_to_record_batch,
              schema=beam.pvalue.AsSingleton(schema_pcoll),
              converter=beam.pvalue.AsSingleton(converter_pcoll),
              passthrough_keys=passthrough_keys,
              input_metadata=input_metadata))

    _clear_shared_state_after_barrier(self.pipeline, output_data)

    return output_data, output_metadata