def profile_cuda_kernels()

in benchmarks/transformer_fusion_patterns/benchmark_helper.py [0:0]


def profile_cuda_kernels(fn, args, string_id="Model time"):
    """Print CUDA kernel profiles for the forward and backward pass of ``fn``.

    The callable is first warmed up with repeated forward + backward
    iterations. Then two kernel tables (sorted by total CUDA time, top 30
    rows) are printed: one covering the forward call and one covering only
    ``backward``.

    Args:
        fn: Callable invoked as ``fn(*args)``. Its result must be a single
            tensor, because ``torch.rand_like`` and ``.backward`` are called
            on it.
        args: Sequence of positional arguments for ``fn``. Tensor arguments
            have their ``.grad`` reset to ``None`` before each profiled run;
            non-tensor arguments (Python scalars, flags, ...) are passed
            through untouched.
        string_id: Label embedded in the printed banners.

    Returns:
        None. All results are written to stdout.

    Note:
        ``torch.cuda.synchronize()`` is called, so a CUDA device is required.
        Gradients produced by the final backward call are left on the tensor
        arguments.
    """
    banner = "################################################"
    print(banner)
    print(f"#### Profiling for {string_id} starts #########")
    print(banner)

    warmup = 50
    n_repeats = 1
    # NOTE(review): presumably the number of layers a single ``fn`` call
    # covers, used to scale the iteration counts -- confirm with the author.
    n_layers = 1
    inputs = list(args)

    def reset_grads():
        # Only tensors carry a ``.grad`` slot; assigning it on a plain Python
        # scalar would raise AttributeError.
        for arg in inputs:
            if isinstance(arg, torch.Tensor):
                arg.grad = None

    # One forward call up front to learn the output shape for the upstream
    # gradient that every backward call reuses.
    ref = fn(*inputs)
    gO = torch.rand_like(ref)
    for _ in range(warmup // n_layers):
        fn(*inputs).backward(gO)

    # Drain the warmup work so it does not leak into the profiled region.
    torch.cuda.synchronize()

    # Forward profile
    print(f"###### Forward profile for {string_id} starts #####")
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("baseline"):
            for _ in range(n_repeats // n_layers):
                reset_grads()
                fn(*inputs)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
    print(f"###### Forward profile for {string_id} ends #####")

    # Backward profile: the forward pass runs outside the profiler so that
    # only the kernels launched by ``backward`` are recorded.
    for _ in range(n_repeats // n_layers):
        reset_grads()
        out = fn(*inputs)

        print(f"###### Backward profile for {string_id} starts #####")
        # Make sure the forward kernels have finished before profiling starts.
        torch.cuda.synchronize()
        with profile(
            activities=[ProfilerActivity.CUDA], record_shapes=True
        ) as prof:
            with record_function("baseline"):
                out.backward(gO)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
        torch.cuda.synchronize()
        print(f"###### Backward profile for {string_id} ends #####")

    print(banner)
    print(f"#### Profiling for {string_id} ends #########")
    print("################################################\n\n\n\n")