def expand()

in tfx/components/transform/executor.py [0:0]


    def expand(self, pipeline: beam.Pipeline):
      """Chooses which analysis datasets to read and wires up cache reading.

      Args:
        pipeline: The Beam pipeline that the cache-reading transform (if any)
          is applied to.

      Returns:
        A 3-tuple `(analyze_data_dict, input_cache, estimated_stage_count)`:
          analyze_data_dict: Dict keyed by each dataset's `dataset_key`.
            - A value of `None` means the key was not among the keys returned
              for analysis. It was filtered out by `get_analysis_dataset_keys`,
              presumably because the input cache already covers it (confirm).
            - Otherwise the value is the dataset itself.
            - Datasets that still need analysis but exceed the
              cacheable-datasets limit are omitted entirely, so the caller must
              make sure they are read.
            - When the stage-count cutoff triggers, every dataset maps to
              itself.
          input_cache:
            - The output of `ReadAnalysisCacheFromFS` when an input cache dir
              is configured.
            - `{}` when only an output cache dir is configured.
            - `None` when no cache is read or written. This covers both "no
              cache dirs" and the stage-count cutoff.
          estimated_stage_count: `analyzers * analysis_paths * 10`.
      """
      # TODO(b/170304777): Remove this Create once the issue is fixed in beam.
      # The dummy Create keeps beam from treating this PTransform as a
      # primitive.
      _ = pipeline | 'WorkaroundForBug170304777' >> beam.Create([None])

      dataset_keys = [entry.dataset_key for entry in self._analyze_data_list]
      cache_entry_keys = (
          tft_beam.analysis_graph_builder.get_analysis_cache_entry_keys(
              self._preprocessing_fn, self._feature_spec_or_typespec,
              dataset_keys, self._force_tf_compat_v1))
      num_analyzers = len(cache_entry_keys)
      num_paths = len(dataset_keys)

      # Rough stage estimate when running without incremental cache:
      # analyzers * analysis_paths * 10.
      # TODO(b/37788560): Remove this restriction when a greater number of
      # stages can be handled efficiently.
      estimated_stage_count = num_analyzers * num_paths * 10
      if (not self._enable_incremental_cache and
          estimated_stage_count > _MAX_ESTIMATED_STAGES_COUNT):
        logging.warning(
            'Disabling cache because otherwise the number of stages might be '
            'too high (%d analyzers, %d analysis paths)', num_analyzers,
            num_paths)
        # A None input cache turns off both reading and writing of cache.
        passthrough = {
            entry.dataset_key: entry for entry in self._analyze_data_list
        }
        return passthrough, None, estimated_stage_count

      limit_from_analyzers = _GetCacheableDatasetsCount(
          num_analyzers, self._is_stats_enabled)
      # Without incremental cache every analysis path may be kept.
      cacheable_limit = (
          limit_from_analyzers if self._enable_incremental_cache else num_paths)

      # None signals that this pipeline neither reads nor writes cache.
      input_cache = None
      if self._input_cache_dir is not None:
        logging.info('Reading the following analysis cache entry keys: %s',
                     cache_entry_keys)
        input_cache = (
            pipeline
            | 'ReadCache' >> analyzer_cache.ReadAnalysisCacheFromFS(
                self._input_cache_dir,
                dataset_keys,
                source=self._cache_source,
                cache_entry_keys=cache_entry_keys))
      elif self._output_cache_dir is not None:
        # Write-only cache: analysis starts from an empty input cache.
        input_cache = {}

      if input_cache is not None:
        keys_to_analyze = (
            tft_beam.analysis_graph_builder.get_analysis_dataset_keys(
                self._preprocessing_fn, self._feature_spec_or_typespec,
                dataset_keys, input_cache, self._force_tf_compat_v1))
      else:
        # With cache disabled nothing can be skipped: analyze every dataset.
        keys_to_analyze = dataset_keys

      analyze_data_dict = {}
      kept = 0
      # Walk the datasets last-to-first: later datasets are assumed more likely
      # to be processed again in a future iteration, so they claim the limited
      # slots first.
      for entry in reversed(self._analyze_data_list):
        key = entry.dataset_key
        if key not in keys_to_analyze:
          analyze_data_dict[key] = None
        elif kept < cacheable_limit:
          analyze_data_dict[key] = entry
          kept += 1
        # Any other dataset that still needs analysis is deliberately left out;
        # the caller has to make sure its data is read.

      return analyze_data_dict, input_cache, estimated_stage_count