def benchmark()

in benchmark.py [0:0]

Runs the selected benchmark mode(s) (inference, train, both, or profiling) for a single model and returns an OrderedDict of results, stopping early if a run fails.


def benchmark(args):
    if args.amp:
        # --amp overrides --precision: float16 maps to plain 'amp', any other
        # dtype to 'amp_<dtype>' (e.g. 'amp_bfloat16')
        _logger.warning("Overriding precision to 'amp' since --amp flag set.")
        args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
    _logger.info(f'Benchmarking in {args.precision} precision. '
                 f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                 f'torchscript {"enabled" if args.torchscript else "disabled"}')

    # forward all remaining CLI args to the runner; 'amp' was already folded into
    # args.precision above, while model and batch_size are passed explicitly
    bench_kwargs = vars(args).copy()
    bench_kwargs.pop('amp')
    model = bench_kwargs.pop('model')
    batch_size = bench_kwargs.pop('batch_size')

    # default: inference benchmark only
    bench_fns = (InferenceBenchmarkRunner,)
    prefixes = ('infer',)
    if args.bench == 'both':
        bench_fns = (
            InferenceBenchmarkRunner,
            TrainBenchmarkRunner,
        )
        prefixes = ('infer', 'train')
    elif args.bench == 'train':
        bench_fns = (TrainBenchmarkRunner,)
        prefixes = ('train',)
    elif args.bench.startswith('profile'):
        # a specific profiler is used if named in the bench mode string,
        # otherwise the runner defaults to deepspeed with fvcore as a fallback
        if 'deepspeed' in args.bench:
            assert has_deepspeed_profiling, "deepspeed must be installed to use deepspeed flop counter"
            bench_kwargs['profiler'] = 'deepspeed'
        elif 'fvcore' in args.bench:
            assert has_fvcore_profiling, "fvcore must be installed to use fvcore flop counter"
            bench_kwargs['profiler'] = 'fvcore'
        bench_fns = (ProfileRunner,)
        batch_size = 1  # profiling/flop counting runs a single sample
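    # NOTE: in profile mode, `prefixes` keeps its default ('infer',), so the single
    # profiling run is still paired with the 'infer' key prefix in the loop below.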

    model_results = OrderedDict(model=model)
    # zip pairs each prefix with its runner class from the dispatch above
    for prefix, bench_fn in zip(prefixes, bench_fns):
        run_results = _try_run(
            model,
            bench_fn,
            bench_kwargs=bench_kwargs,
            initial_batch_size=batch_size,
            no_batch_size_retry=args.no_retry,
        )
        if prefix and 'error' not in run_results:
            # namespace the result keys (e.g. 'samples_per_sec' -> 'infer_samples_per_sec')
            # so that infer and train results can coexist in one row
            run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
        model_results.update(run_results)
        if 'error' in run_results:
            break  # stop after the first failed run
    if 'error' not in model_results:
        # collapse per-run param counts into a single 'param_count' column; note the
        # inner pop() is evaluated unconditionally (it is the outer pop's default),
        # so 'train_param_count' is always removed here
        param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
        model_results.setdefault('param_count', param_count)
        model_results.pop('train_param_count', 0)  # already removed above; harmless guard
    return model_results
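

Usage, for context: a minimal sketch of driving this entry point programmatically. This is an assumption, not the script's documented interface; the real benchmark.py builds args with argparse and forwards the full argument set to the runner classes via bench_kwargs. The Namespace below carries only the fields this function reads directly, and all values (and the 'infer_samples_per_sec' result key) are illustrative.


# Hypothetical sketch: assumes InferenceBenchmarkRunner, _try_run, etc. from
# benchmark.py are in scope alongside benchmark().
from types import SimpleNamespace

args = SimpleNamespace(
    model='resnet50',       # model name, value illustrative
    bench='infer',          # 'infer', 'train', 'both', or 'profile[_deepspeed|_fvcore]'
    batch_size=256,
    amp=False,              # if True, overrides the precision field below
    amp_dtype='float16',
    precision='float32',
    channels_last=False,    # True selects NHWC layout
    torchscript=False,
    no_retry=False,         # True disables batch-size backoff in _try_run
)
results = benchmark(args)
print(results['model'], results.get('infer_samples_per_sec'))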