in src/train.py [0:0]
def _input(epochs, batch_size, channel, channel_name, hvd=None):
    """Build the ``tf.data`` input pipeline for one input channel.

    The channel's input directory and input mode are looked up in the
    module-level ``args`` object (``args.training_env`` and
    ``args.data_config``). Channels whose name contains ``'test'`` are parsed
    with ``_dataset_parser_with_slide`` and returned unrepeated and unbatched;
    all other channels are repeated, parsed with ``_dataset_parser`` and
    batched. Channels whose name contains ``'train'`` are additionally
    shuffled before batching.

    Args:
        epochs: Repeat count passed to ``dataset.repeat`` for non-test
            channels.
        batch_size: Batch size for non-test channels; also used to size the
            shuffle buffer for training channels.
        channel: Unused by this function; kept so existing callers keep
            working.
        channel_name: Base name of the input channel. The substrings
            ``'train'`` and ``'test'`` select the processing path.
        hvd: Optional object exposing ``rank()`` (presumably the Horovod
            module -- confirm against the caller). When given, the channel
            name is suffixed with ``rank() % 4``.

    Returns:
        The assembled dataset object.

    Raises:
        KeyError: Presumably, when the resolved channel name is missing from
            ``args.training_env['channel_input_dirs']`` or
            ``args.data_config`` (assumes these are plain dicts).
    """
    # Identity comparison is the idiomatic None check and cannot be affected
    # by a custom ``__ne__`` on whatever object is passed in.
    if hvd is not None:
        # NOTE(review): the modulus 4 presumably matches the number of
        # per-rank channel shards (e.g. processes per host) -- confirm.
        channel_name = '{}_{}'.format(channel_name, hvd.rank() % 4)

    print("The channel name is ", channel_name)
    channel_input_dir = args.training_env['channel_input_dirs'][channel_name]
    print("The corresponding input directory is ", channel_input_dir)

    mode = args.data_config[channel_name]['TrainingInputMode']
    if mode == 'Pipe':
        # Imported lazily so the package is only required when Pipe mode is
        # actually used.
        from sagemaker_tensorflow import PipeModeDataset
        dataset = PipeModeDataset(channel=channel_name, record_format='TFRecord')
    else:
        filenames = get_filenames(channel_input_dir, hvd)
        print("The correpsonding filenames are", filenames)
        dataset = tf.data.TFRecordDataset(filenames)

    if 'test' in channel_name:
        # Test data is read in a single pass with its own parser.
        dataset = dataset.map(_dataset_parser_with_slide)
    else:
        dataset = dataset.repeat(epochs)
        dataset = dataset.map(_dataset_parser)

    if 'train' in channel_name:
        # Ensure that the capacity is sufficiently large to provide good
        # random shuffling.
        buffer_size = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 0.4) + 3 * batch_size
        dataset = dataset.shuffle(buffer_size=buffer_size)

    # Batch it up (only for train and valid).
    if 'test' not in channel_name:
        dataset = dataset.batch(batch_size, drop_remainder=True)
        # NOTE(review): prefetch is kept together with batching, i.e. it is
        # skipped for test channels -- confirm this nesting is intended.
        dataset = dataset.prefetch(10)

    return dataset