in source_directory/training/training_script.py [0:0]
def load_dataset(epochs, batch_size, channel_name):
    # load the channel's TFRecord files with the tf.data API
    if args.tfdata_s3uri:
        s3uri = '{}/{}'.format(args.tfdata_s3uri, channel_name)
        bucket = s3uri.split('/')[2]
        prefix = os.path.join(*s3uri.split('/')[3:])
        s3_client = boto3.client('s3')
        objects_list = s3_client.list_objects(Bucket=bucket, Prefix=prefix)
        files = []
        for obj in objects_list['Contents']:
            files.append('s3://{}/{}'.format(bucket, obj['Key']))
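        # with Horovod, shard the file list so each worker reads its own subset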
        if args.use_horovod:
            files, smallest_amount_samples = get_files_for_processor(files)
            print("Files to be read from {} set:".format(channel_name))
            for f in files:
                print(f)
    else:
        pass
    dataset = tf.data.TFRecordDataset(files)
    # compute number of batches per epoch
    if args.use_horovod:
        num_batches_per_epoch = math.floor(smallest_amount_samples/batch_size)
    else:
        num_samples = sum(1 for _ in dataset)
        print("{} set has {} samples.".format(channel_name, num_samples))
        num_batches_per_epoch = math.floor(num_samples/batch_size)
    # parse records
    dataset = dataset.map(_dataset_parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # shuffle and augment records for the training set
    if channel_name == 'train':
        dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.map(image_augmentation, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # batch the dataset
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # repeat and prefetch
    dataset = dataset.repeat(epochs)
    dataset = dataset.prefetch(1000)
    return dataset, num_batches_per_epoch
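# ---------------------------------------------------------------------------
# Illustrative sketch only -- get_files_for_processor() is defined elsewhere
# in this script and its real implementation may differ. This hypothetical
# version shows one way to shard the TFRecord files across Horovod workers
# and estimate the smallest per-worker sample count, so every rank derives
# the same num_batches_per_epoch. It assumes horovod.tensorflow is imported
# as hvd, hvd.init() has been called, and the TFRecordDataset can be iterated
# eagerly (as the sum(1 for _ in dataset) count above already assumes).
def get_files_for_processor_sketch(all_files):
    # this worker keeps every size-th file, starting at its own rank
    shard = all_files[hvd.rank()::hvd.size()]
    # count the records in the local shard with a single pass over the files
    samples_in_shard = sum(1 for _ in tf.data.TFRecordDataset(shard))
    # assuming the files hold roughly equal numbers of records, the smallest
    # shard is the one with the fewest files; scale the local count down to
    # that lower bound so all ranks agree on batches per epoch
    smallest_shard_files = len(all_files) // hvd.size()
    smallest_amount_samples = samples_in_shard * smallest_shard_files // max(len(shard), 1)
    return shard, smallest_amount_samples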