def process_record_dataset()

in tensorflow_examples/profiling/imagenet_preprocessing_ineffecient_input_pipeline.py [0:0]


def process_record_dataset(dataset,
                           is_training,
                           batch_size,
                           shuffle_buffer,
                           parse_record_fn,
                           num_epochs=1,
                           dtype=tf.float32,
                           datasets_num_private_threads=None,
                           drop_remainder=False,
                           tf_data_experimental_slack=False):
  """Given a Dataset with raw records, return a batched Dataset of examples.

  This pipeline is intentionally de-optimized for profiling. The parallel
  `map` and the final `prefetch` of the usual pipeline are disabled (see the
  BEGIN_DEOPTIMIZE / END_DEOPTIMIZE markers). As a result, records are parsed
  sequentially and the work is not overlapped with the consumer of the
  dataset.

  Args:
    dataset: A Dataset representing raw records.
    is_training: A boolean denoting whether the input is for training. When
      True, records are shuffled and the dataset is repeated indefinitely.
      The value is also forwarded to `parse_record_fn`.
    batch_size: The number of samples per batch.
    shuffle_buffer: The buffer size to use when shuffling records. A larger
      value results in better randomness, but smaller values reduce startup
      time and use less memory. Only used when `is_training` is True.
    parse_record_fn: A function called as
      `parse_record_fn(raw_record, is_training, dtype)` that returns the
      corresponding (image, label) pair.
    num_epochs: Currently unused by this function. Training data is repeated
      indefinitely regardless of this value, so the consumer must bound
      iteration itself (e.g. with a step count). Evaluation data is not
      repeated.
    dtype: Data type to use for images/features; forwarded to
      `parse_record_fn`.
    datasets_num_private_threads: Number of threads for a private
      threadpool created for all datasets computation. If falsy (None or 0),
      no private threadpool option is set.
    drop_remainder: A boolean indicating whether to drop the last partial
      batch. If True, the batch dimension will be static.
    tf_data_experimental_slack: Whether to enable tf.data's
      `experimental_slack` option.

  Returns:
    A `tf.data.Dataset` of batched (image, label) pairs ready for iteration.
  """
  # Defines a specific size thread pool for tf.data operations.
  # NOTE(review): `experimental_threading` is the older name of
  # `tf.data.Options.threading`. Confirm it is still accepted by the TF
  # version this example targets.
  if datasets_num_private_threads:
    options = tf.data.Options()
    options.experimental_threading.private_threadpool_size = (
        datasets_num_private_threads)
    dataset = dataset.with_options(options)
    logging.info(
        'datasets_num_private_threads: %s', datasets_num_private_threads)

  if is_training:
    # Shuffles records before repeating to respect epoch boundaries.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    # Repeats the dataset indefinitely. `num_epochs` is not applied here, so
    # the consumer decides when to stop iterating.
    dataset = dataset.repeat()

  # Parses the raw records into images and labels.

  # BEGIN_DEOPTIMIZE
  # Remove data autotuning
  # dataset = dataset.map(
  #    lambda value: parse_record_fn(value, is_training, dtype),
  #    num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # END_DEOPTIMIZE

  # Without `num_parallel_calls`, `map` parses records sequentially.
  dataset = dataset.map(
      lambda value: parse_record_fn(value, is_training, dtype))
  dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

  # In the optimized pipeline, a final prefetch here moves all of the above
  # processing work into the background and keeps it out of the critical
  # training path. Operations between the final prefetch and the iterator's
  # get_next call happen synchronously at run time. Setting buffer_size to
  # tf.data.experimental.AUTOTUNE lets DistributionStrategies adjust how many
  # batches to fetch based on how many devices are present. That prefetch is
  # disabled below, so the pipeline above does its work when get_next is
  # called.

  # BEGIN_DEOPTIMIZE
  # Remove the prefetch
  # dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  # END_DEOPTIMIZE

  # NOTE(review): the tf.data.Options docs say slack is introduced in the last
  # `prefetch` of the pipeline, if one exists. This pipeline has no prefetch,
  # so this option presumably has no effect here. Confirm.
  options = tf.data.Options()
  options.experimental_slack = tf_data_experimental_slack
  dataset = dataset.with_options(options)

  return dataset