def main()

in mobile_cv/model_zoo/tools/jit_speed_benchmark.py [0:0]


def main():
    """Load a TorchScript model and benchmark its forward pass.

    Flow: parse CLI args, initialize the environment, load the model from
    ``args.model``, and build inputs with ``parse_inputs``. When
    ``args.on_gpu`` is set, the model and inputs are moved to CUDA. Then
    ``args.warmup`` warmup iterations run, followed by ``args.iter`` timed
    iterations (both through ``bench_model``). Finally, the optional
    profiling/reporting helpers get a chance to run.
    """
    args = parse_args()

    init_env(args)

    # load jit model
    model = torch.jit.load(args.model)
    # prepare input
    input_data = parse_inputs(args)

    # Place model and inputs on the target device *before* defining the
    # closure, so it is explicit which tensors the benchmark uses.
    if args.on_gpu:
        model.cuda()
        input_data = move_to_device(input_data, "cuda")

    on_gpu = args.on_gpu

    def run_model():
        out = model(*input_data)
        if on_gpu:
            # CUDA kernels launch asynchronously; block until they finish so
            # the caller's timing covers real execution, not just the launch.
            # (Redundant but harmless if bench_model already synchronizes.)
            torch.cuda.synchronize()
        return out

    # run warmup trials
    print("Warming up...")
    bench_model(
        run_model,
        args.warmup,
        args.check_freq,
        prefix="WARMUP",
        run_garbage_collector=args.run_garbage_collector,
    )
    # run real trials
    print("Benchmarking...")
    runtimes = bench_model(
        run_model,
        args.iter,
        args.check_freq,
        prefix="RUN",
        run_garbage_collector=args.run_garbage_collector,
    )

    # per op profiling (the "maybe_" helpers presumably decide from args
    # whether to do anything -- NOTE(review): confirm in their definitions)
    maybe_run_autograd_profile(args, run_model)
    maybe_ai_pep_output(args, runtimes)