# petastorm/benchmark/cli.py
# Copyright (c) 2017-2018 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This command line utility instantiates an instance of a Reader and measures its throughput. """
from __future__ import division
from __future__ import print_function

import argparse
import logging
import sys

from petastorm.benchmark.throughput import reader_throughput, \
    WorkerPoolType, ReadMethod
logger = logging.getLogger(__name__)


def _parse_args(args):
    # If the min-after-dequeue value is not explicitly set from the command line, it is calculated as the total
    # shuffling queue size multiplied by this ratio.
DEFAULT_MIN_AFTER_DEQUEUE_TO_QUEUE_SIZE_RATIO = 0.8
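    # For example, with the default --shuffling-queue-size of 500 and no explicit
    # --min-after-dequeue, the default resolves to 0.8 * 500 = 400 elements.
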
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('dataset_path', type=str, help='Path to a petastorm dataset')
parser.add_argument('--field-regex', type=str, nargs='+',
help='A list of regular expressions. Only fields that match one of the regex patterns will '
'be used during the benchmark.')
parser.add_argument('-w', '--workers-count', type=int, default=3,
help='Number of workers used by the reader')
parser.add_argument('-p', '--pool-type', type=WorkerPoolType, default=WorkerPoolType.THREAD,
choices=list(WorkerPoolType),
                        help='Type of worker pool used by the reader')
parser.add_argument('-m', '--warmup-cycles', type=int, default=200,
help='Number of warmup read cycles. Warmup read cycles run before measurement cycles and '
'the throughput during these cycles is not accounted for in the reported results.')
    parser.add_argument('-n', '--measure-cycles', type=int, default=1000,
                        help='Number of cycles used for benchmark measurements. Measurement cycles run after the '
                             'warmup cycles.')
    parser.add_argument('--profile-threads', dest='profile_threads', action='store_true',
                        help='Enables thread profiling. Results are printed when the thread pool is shut down.')
    parser.add_argument('-d', '--read-method', type=ReadMethod, choices=list(ReadMethod),
                        default=ReadMethod.PYTHON,
                        help='Which read method to use. \'python\': use the pure Python implementation; '
                             '\'tf\': construct a small TF graph that streams data from the Python implementation.')
    parser.add_argument('-q', '--shuffling-queue-size', type=int, default=500, required=False,
                        help='Size of the shuffling queue used to decorrelate row-group chunks.')
parser.add_argument('--min-after-dequeue', type=int, default=None, required=False,
help='Minimum number of elements in a shuffling queue before entries can be read from it. '
                             'Default value is set to {:.0f}%% of the --shuffling-queue-size '
'parameter'.format(100 * DEFAULT_MIN_AFTER_DEQUEUE_TO_QUEUE_SIZE_RATIO))
    parser.add_argument('--pyarrow-serialize', action='store_true', required=False,
                        help='When specified, the faster pyarrow.serialize library is used. However, it does not '
                             'support all data types and implicitly converts some datatypes (e.g. int64->int32), '
                             'which may trigger errors when reading the data from TensorFlow.')
parser.add_argument('-vv', action='store_true', default=False, help='Sets logging level to DEBUG.')
parser.add_argument('-v', action='store_true', default=False, help='Sets logging level to INFO.')
    args = parser.parse_args(args)

    if args.min_after_dequeue is None:
        args.min_after_dequeue = int(DEFAULT_MIN_AFTER_DEQUEUE_TO_QUEUE_SIZE_RATIO * args.shuffling_queue_size)

    return args


def _main(args):
logging.basicConfig()
args = _parse_args(args)
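
    # Verbosity: -v sets INFO, -vv sets DEBUG; if both are given, DEBUG wins because it is applied last.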
if args.v:
logging.getLogger().setLevel(logging.INFO)
if args.vv:
logging.getLogger().setLevel(logging.DEBUG)
results = reader_throughput(args.dataset_path, args.field_regex, warmup_cycles_count=args.warmup_cycles,
measure_cycles_count=args.measure_cycles, pool_type=args.pool_type,
loaders_count=args.workers_count, profile_threads=args.profile_threads,
read_method=args.read_method, shuffling_queue_size=args.shuffling_queue_size,
min_after_dequeue=args.min_after_dequeue, pyarrow_serialize=args.pyarrow_serialize)
logger.info('Done')
print('Average sample read rate: {:1.2f} samples/sec; RAM {:1.2f} MB (rss); '
'CPU {:1.2f}%'.format(results.samples_per_second, results.memory_info.rss / 2 ** 20, results.cpu))


def main():
_main(sys.argv[1:])


if __name__ == '__main__':
    main()