# tf-distribution-options/code/utilities.py
def process_input(epochs, batch_size, channel, channel_name, data_config):
    """Build a batched tf.data input pipeline for one SageMaker channel.

    Args:
        epochs: Number of times to repeat the dataset (None repeats forever).
        batch_size: Number of examples per batch.
        channel: Filesystem path of the channel's data directory (File mode).
        channel_name: Channel name, e.g. 'train' or 'eval'; only 'train'
            data is shuffled.
        data_config: Mapping from channel name to its SageMaker input
            configuration; only 'TrainingInputMode' ('Pipe' or 'File') is
            read here.

    Returns:
        A (image_batch, label_batch) tuple of tensors produced by a
        one-shot iterator over the batched dataset.
    """
    mode = data_config[channel_name]['TrainingInputMode']
    filenames = _get_filenames(channel_name, channel)
    # Lazy %-args: the message is only formatted if INFO logging is enabled.
    logging.info("Running %s in %s mode", channel_name, mode)

    if mode == 'Pipe':
        # Deferred import: sagemaker_tensorflow is only needed (and may only
        # be installed) when the channel streams data in Pipe mode.
        from sagemaker_tensorflow import PipeModeDataset
        dataset = PipeModeDataset(channel=channel_name, record_format='TFRecord')
    else:
        dataset = tf.data.TFRecordDataset(filenames)

    # Repeat for the requested number of epochs (forever when epochs is None).
    dataset = dataset.repeat(epochs)
    dataset = dataset.prefetch(10)

    # Parse raw TFRecords into (image, label) tensors.
    dataset = dataset.map(
        _dataset_parser, num_parallel_calls=10)

    # Shuffle only the training channel; evaluation order is left untouched.
    if channel_name == 'train':
        # Ensure that the capacity is sufficiently large to provide good
        # random shuffling.
        buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
        dataset = dataset.shuffle(buffer_size=buffer_size)

    # Drop the remainder so every batch has a static batch dimension.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    iterator = dataset.make_one_shot_iterator()
    image_batch, label_batch = iterator.get_next()
    return image_batch, label_batch