def load_dataset()

in source_directory/training/training_script.py


import math
import os

import boto3
import tensorflow as tf


# `args`, `get_files_for_processor`, `_dataset_parser`, and `image_augmentation`
# are defined elsewhere in training_script.py.
def load_dataset(epochs, batch_size, channel_name):

    # load TFRecord files from S3 with the tf.data API
    if args.tfdata_s3uri:
        s3uri = '{}/{}'.format(args.tfdata_s3uri, channel_name)
        bucket = s3uri.split('/')[2]
        prefix = os.path.join(*s3uri.split('/')[3:])
        s3_client = boto3.client('s3')
        objects_list = s3_client.list_objects(Bucket=bucket, Prefix=prefix)
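        # list_objects returns at most 1000 keys per call; a paginator would be
        # needed if a channel contains more objects than that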
        files = []
        for obj in objects_list['Contents']:
            files.append('s3://{}/{}'.format(bucket, obj['Key']))
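        # with Horovod, each process reads only its own subset of the files;
        # the smallest per-process sample count is used below to compute a
        # common number of batches per epoch (get_files_for_processor is
        # defined elsewhere in training_script.py)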
        if args.use_horovod:
            files, smallest_amount_samples = get_files_for_processor(files)
        print("Files to be read from {} set:".format(channel_name))
        for f in files:
            print(f)
    else:
        # only the S3/tf.data path is implemented here; without an S3 URI
        # there is nothing for TFRecordDataset to read below
        raise ValueError('args.tfdata_s3uri must be set to load TFRecord data')
        
    dataset = tf.data.TFRecordDataset(files)

    # compute number of batches per epoch
    if args.use_horovod:
        num_batches_per_epoch = math.floor(smallest_amount_samples/batch_size)
    else:
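        # counting samples this way requires one full pass over the TFRecord files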
        num_samples = sum(1 for _ in dataset)
        print("{} set has {} samples.".format(channel_name, num_samples))
        num_batches_per_epoch = math.floor(num_samples/batch_size)

    # parse records
    dataset = dataset.map(_dataset_parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # shuffle records for training set
    if channel_name == 'train':
        dataset = dataset.shuffle(buffer_size=1000)
        dataset = dataset.map(image_augmentation, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # batch the dataset
    dataset = dataset.batch(batch_size, drop_remainder=True)

    # repeat and prefetch
    dataset = dataset.repeat(epochs)
    dataset = dataset.prefetch(1000)
    
    return dataset, num_batches_per_epoch
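
For context, below is a minimal sketch of how load_dataset might be invoked from the rest of training_script.py. The argparse flag names, the module-level args variable, and the loop shown here are illustrative assumptions, not the script's actual entry point.

# hypothetical driver; flag names mirror the attributes referenced above
import argparse


def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--tfdata_s3uri', type=str, default=None)
    parser.add_argument('--use_horovod', action='store_true')
    return parser.parse_args()


args = _parse_args()

if __name__ == '__main__':
    train_dataset, steps_per_epoch = load_dataset(args.epochs, args.batch_size, 'train')

    # the dataset repeats for `epochs`, so cap iteration with the returned step count
    for step, batch in enumerate(train_dataset.take(steps_per_epoch)):
        pass  # the training step for one batch would go here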