def bench_vanilla_sddmm_cuda()

in benchmark/bench_vanilla_sddmm.py [0:0]


def bench_vanilla_sddmm_cuda(adj_scipy_coo, feat_len):
    """Time the vanilla SDDMM CUDA module over a sweep of CUDA block counts.

    For every block count produced by ``exp_range``, a fresh
    ``VanillaSDDMMcuda`` module is built, fed random float32 features, and
    its average run time over 5 runs is printed.

    Args:
        adj_scipy_coo: Adjacency object exposing ``shape`` and ``nnz``
            (presumably a SciPy COO sparse matrix, going by the name --
            TODO confirm against callers).
        feat_len: Second dimension of both feature placeholders.

    Returns:
        None. Results are reported through ``print`` only.
    """
    shape = adj_scipy_coo.shape
    n_rows, n_cols = shape[0], shape[1]
    # Sweep ceiling: 262144, or one thirty-second of the stored entries,
    # whichever is smaller.
    max_blocks = min(262144, adj_scipy_coo.nnz // 32)

    def _random_float32(placeholder):
        """Return random float32 host data shaped like ``placeholder``."""
        dims = get_const_tuple(placeholder.shape)
        return np.random.random(dims).astype('float32')

    def _run_once(num_cuda_blocks):
        """Build the module for one block count, time it, print the result."""
        module = VanillaSDDMMcuda(adj_scipy_coo)
        # Source features are sized by columns, destination features by rows.
        placeholders = [
            te.placeholder((n_cols, feat_len)),
            te.placeholder((n_rows, feat_len)),
        ]
        # No compute arguments; the block count is the only schedule argument.
        module.build(placeholders, {}, {'num_cuda_blocks': num_cuda_blocks})
        # Generate all host data first, then copy it to the module's context.
        host_feats = [_random_float32(p) for p in placeholders]
        device_feats = [tvm.nd.array(arr, module.ctx) for arr in host_feats]
        num_runs = 5
        avg_time = module.measure_average_time(device_feats, num_runs)
        # NOTE(review): the * 1000 with an "ms" label implies the measured
        # value is in seconds -- confirm against measure_average_time.
        print("average time of {} runs: {} ms".format(num_runs, avg_time * 1000))

    # exp_range is defined elsewhere; presumably it yields 64, 256, 1024, ...
    # (factor 4) up to max_blocks -- TODO confirm its end-point semantics.
    for block_count in exp_range(64, max_blocks, 4):
        print()
        print("num_cuda_blocks:", block_count)
        _run_once(block_count)