def input_fn()

in tensorflow_examples/profiling/imagenet_preprocessing_ineffecient_input_pipeline.py


def input_fn(is_training,
             data_dir,
             batch_size,
             num_epochs=1,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False,
             tf_data_experimental_slack=False,
             training_dataset_cache=False,
             filenames=None):
  """Input function which provides batches for train or eval.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    num_epochs: The number of epochs to repeat the dataset.
    dtype: Data type to use for images/features.
    datasets_num_private_threads: Number of private threads for tf.data.
    parse_record_fn: Function to use for parsing the records.
    input_context: A `tf.distribute.InputContext` object passed in by
      `tf.distribute.Strategy`.
    drop_remainder: A boolean indicating whether to drop the remainder of the
      batches. If True, the batch dimension will be static.
    tf_data_experimental_slack: Whether to enable tf.data's
      `experimental_slack` option.
    training_dataset_cache: Whether to cache the training dataset on workers.
      Typically used to improve training performance when training data is in
      remote storage and can fit into worker memory.
    filenames: Optional field for providing the file names of the TFRecords.

  Returns:
    A dataset that can be used for iteration.
  """
  if filenames is None:
    filenames = get_filenames(is_training, data_dir)
  dataset = tf.data.Dataset.from_tensor_slices(filenames)

  if input_context:
    logging.info(
        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
        input_context.input_pipeline_id, input_context.num_input_pipelines)
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  if is_training:
    # Shuffle the input files
    dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)

  # Convert to individual records.
  # In the optimized version (commented out below), cycle_length=10 means that
  # up to 10 files are read and deserialized in parallel. You may want to
  # increase this number if you have a large number of CPU cores.

  # BEGIN_DEOPTIMIZE
  # Deoptimized by removing the cycle length and tf.data autotuning:
  # dataset = dataset.interleave(
  #    tf.data.TFRecordDataset,
  #    cycle_length=10,
  #    num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # END_DEOPTIMIZE

  # Without `num_parallel_calls`, interleave reads the files sequentially;
  # this is the deliberate bottleneck that this profiling example demonstrates.
  dataset = dataset.interleave(tf.data.TFRecordDataset)

  if is_training and training_dataset_cache:
    # Improve training performance when training data is in remote storage and
    # can fit into worker memory.
    dataset = dataset.cache()

  return process_record_dataset(
      dataset=dataset,
      is_training=is_training,
      batch_size=batch_size,
      shuffle_buffer=_SHUFFLE_BUFFER,
      parse_record_fn=parse_record_fn,
      num_epochs=num_epochs,
      dtype=dtype,
      datasets_num_private_threads=datasets_num_private_threads,
      drop_remainder=drop_remainder,
      tf_data_experimental_slack=tf_data_experimental_slack,
  )
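
A minimal usage sketch for orientation, assuming the surrounding module provides `tf`, `parse_record`, and the module-level constants; the paths and hyperparameters below are placeholders, not values from the source file:

# Single-host usage (hypothetical values):
dataset = input_fn(
    is_training=True,
    data_dir='/tmp/imagenet/tfrecords',  # placeholder TFRecord directory
    batch_size=128,
    num_epochs=1)
for images, labels in dataset.take(1):  # assumes parse_record yields (image, label)
  print(images.shape, labels.shape)

# Under tf.distribute, the InputContext used for sharding above is supplied
# per input pipeline by distribute_datasets_from_function:
strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.distribute_datasets_from_function(
    lambda ctx: input_fn(
        is_training=True,
        data_dir='/tmp/imagenet/tfrecords',
        batch_size=128 // strategy.num_replicas_in_sync,  # per-replica batch
        input_context=ctx))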