in tfx/components/transform/executor.py [0:0]
def expand(self, pipeline: beam.Pipeline):
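  """Optimizes the analysis run by consulting the analyzer cache.

  Args:
    pipeline: A beam Pipeline, used to read the analysis cache when an input
      cache directory is configured.

  Returns:
    A tuple of:
      - A dict keyed by analysis dataset key, mapping to the dataset when it
        still needs to be read and analyzed, or to None when it is already
        covered by the input cache. Datasets beyond the cacheable limit are
        omitted and are left for the caller to handle.
      - The input cache (an empty dict when only an output cache is written,
        or None when caching is disabled entirely).
      - The estimated number of pipeline stages.
  """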
  # TODO(b/170304777): Remove this Create once the issue is fixed in beam.
  # Forcing beam to treat this PTransform as non-primitive.
  _ = pipeline | 'WorkaroundForBug170304777' >> beam.Create([None])
  dataset_keys_list = [
      dataset.dataset_key for dataset in self._analyze_data_list
  ]
  cache_entry_keys = (
      tft_beam.analysis_graph_builder.get_analysis_cache_entry_keys(
          self._preprocessing_fn, self._feature_spec_or_typespec,
          dataset_keys_list, self._force_tf_compat_v1))
  # Without incremental cache, we estimate the number of stages in the
  # pipeline to be roughly: analyzers * analysis_paths * 10.
  # TODO(b/37788560): Remove this restriction when a greater number of
  # stages can be handled efficiently.
  estimated_stage_count = (
      len(cache_entry_keys) * len(dataset_keys_list) * 10)
  if (not self._enable_incremental_cache and
      estimated_stage_count > _MAX_ESTIMATED_STAGES_COUNT):
    logging.warning(
        'Disabling cache because otherwise the number of stages might be '
        'too high (%d analyzers, %d analysis paths)', len(cache_entry_keys),
        len(dataset_keys_list))
    # Returning None as the input cache here disables both input and output
    # cache.
    return ({d.dataset_key: d for d in self._analyze_data_list}, None,
            estimated_stage_count)
  if self._enable_incremental_cache:
    cacheable_datasets = _GetCacheableDatasetsCount(
        len(cache_entry_keys), self._is_stats_enabled)
  else:
    cacheable_datasets = len(dataset_keys_list)
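  # `cacheable_datasets` caps how many analysis datasets are kept in
  # `new_analyze_data_dict` below; any datasets beyond that cap are left for
  # the caller to handle.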
  if self._input_cache_dir is not None:
    logging.info('Reading the following analysis cache entry keys: %s',
                 cache_entry_keys)
    input_cache = (
        pipeline
        | 'ReadCache' >> analyzer_cache.ReadAnalysisCacheFromFS(
            self._input_cache_dir,
            dataset_keys_list,
            source=self._cache_source,
            cache_entry_keys=cache_entry_keys))
  elif self._output_cache_dir is not None:
    # There is no cache to read, but an output cache directory was provided,
    # so keep caching enabled by passing an empty input cache.
    input_cache = {}
  else:
    # Using None here to indicate that this pipeline will not read or write
    # cache.
    input_cache = None
  if input_cache is None:
    # Cache is disabled so we won't be filtering out any datasets, and will
    # always perform a flatten over all of them.
    filtered_analysis_dataset_keys = dataset_keys_list
  else:
    filtered_analysis_dataset_keys = (
        tft_beam.analysis_graph_builder.get_analysis_dataset_keys(
            self._preprocessing_fn, self._feature_spec_or_typespec,
            dataset_keys_list, input_cache, self._force_tf_compat_v1))
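  # `filtered_analysis_dataset_keys` holds the dataset keys that still require
  # analysis given the available input cache.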
  cached_datasets_count = 0
  new_analyze_data_dict = {}
  # Process in reverse order, assuming that later datasets are more likely to
  # be processed again in a future iteration.
  for dataset in self._analyze_data_list[::-1]:
    if dataset.dataset_key in filtered_analysis_dataset_keys:
      # Keep at most `cacheable_datasets` entries; a key skipped here is
      # omitted from the dict entirely, and the caller will have to make sure
      # its data is read.
      if cached_datasets_count < cacheable_datasets:
        new_analyze_data_dict[dataset.dataset_key] = dataset
        cached_datasets_count += 1
    else:
      new_analyze_data_dict[dataset.dataset_key] = None
  return (new_analyze_data_dict, input_cache, estimated_stage_count)