in example_zoo/tensorflow/models/keras_cifar_main/official/resnet/resnet_run_loop.py [0:0]
def process_record_dataset(dataset,
                           is_training,
                           batch_size,
                           shuffle_buffer,
                           parse_record_fn,
                           num_epochs=1,
                           dtype=tf.float32,
                           datasets_num_private_threads=None,
                           num_parallel_batches=1):
  """Build a batched, prefetched input pipeline from a Dataset of raw records.

  The pipeline stages, in order, are: prefetch `batch_size` raw records,
  shuffle (training only), repeat for `num_epochs`, parse and batch in one
  fused step, prefetch the finished batches, and optionally move the whole
  pipeline onto a private thread pool.

  Args:
    dataset: A Dataset representing raw records.
    is_training: A boolean denoting whether the input is for training. It
      enables shuffling and is also forwarded to `parse_record_fn`.
    batch_size: The number of samples per batch. Also used as the buffer size
      of the initial raw-record prefetch.
    shuffle_buffer: The buffer size to use when shuffling records. A larger
      value results in better randomness, but smaller values reduce startup
      time and use less memory.
    parse_record_fn: A function called as
      `parse_record_fn(raw_record, is_training, dtype)` that returns the
      corresponding (image, label) pair.
    num_epochs: The number of epochs to repeat the dataset.
    dtype: Data type to use for images/features; forwarded to
      `parse_record_fn`.
    datasets_num_private_threads: Number of threads for a private threadpool
      created for all datasets computation. A falsy value (None or 0) leaves
      the default thread pool in place.
    num_parallel_batches: Number of parallel batches for tf.data.

  Returns:
    Dataset of (image, label) pairs ready for iteration.
  """

  def _parse(raw_record):
    # Binds the mode and dtype so the mapper only needs the raw record.
    return parse_record_fn(raw_record, is_training, dtype)

  # Buffer one batch worth of raw records so that file reading overlaps with
  # the shuffling and parsing stages below.
  records = dataset.prefetch(buffer_size=batch_size)

  # Shuffle ahead of repeat so that shuffling stays within epoch boundaries.
  if is_training:
    records = records.shuffle(buffer_size=shuffle_buffer)

  # Replay the records once per requested epoch.
  records = records.repeat(num_epochs)

  # Fused parse + batch. The last, possibly smaller, batch is kept
  # (drop_remainder=False).
  parse_and_batch = tf.contrib.data.map_and_batch(
      _parse,
      batch_size=batch_size,
      num_parallel_batches=num_parallel_batches,
      drop_remainder=False)
  batches = records.apply(parse_and_batch)

  # Anything placed after the last prefetch runs synchronously when the
  # iterator's get_next is called, so a final prefetch pushes all of the work
  # above into the background. AUTOTUNE leaves the number of buffered batches
  # to be chosen at run time (for example by DistributionStrategies, based on
  # the number of devices).
  batches = batches.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)

  # No private pool requested: return the pipeline on the default thread pool.
  if not datasets_num_private_threads:
    return batches

  # Pin tf.data work to a dedicated, fixed-size thread pool.
  tf.logging.info('datasets_num_private_threads: %s',
                  datasets_num_private_threads)
  private_pool = threadpool.PrivateThreadPool(
      datasets_num_private_threads,
      display_name='input_pipeline_thread_pool')
  return threadpool.override_threadpool(batches, private_pool)