def bench_vanilla_spmm_x86()

in benchmark/bench_vanilla_spmm.py [0:0]


def bench_vanilla_spmm_x86(adj_scipy_csr, feat_len, *, num_runs=5):
    """Sweep partitioning configs of VanillaSpMMx86 and print average run times.

    For every pair (num_col_partitions, num_feat_partitions) produced by the
    two nested ``exp_range`` sweeps below, build a ``VanillaSpMMx86`` module
    for ``adj_scipy_csr``. Each module is fed a random float32 source-feature
    matrix of shape ``(adj_scipy_csr.shape[1], feat_len)``. The function then
    prints the average time reported by ``measure_average_time`` over
    ``num_runs`` runs.

    Args:
        adj_scipy_csr: Sparse adjacency matrix handed to ``VanillaSpMMx86``.
            Only its ``shape`` is read directly here. It is presumably a
            ``scipy.sparse.csr_matrix`` (per the name); confirm against
            callers.
        feat_len: Feature length, i.e. the number of columns of the dense
            source-feature operand.
        num_runs: Number of runs averaged per configuration. This argument is
            keyword-only and defaults to 5, the value previously hard-coded.
    """
    # The dense operand is multiplied from the right, so its row count must
    # equal the sparse matrix's column count.
    num_cols = adj_scipy_csr.shape[1]
    # Upper bound for the feature-partition sweep. It is clamped to 1 so that
    # a feat_len below 16 does not hand exp_range an upper bound of 0, which
    # presumably yields no configurations at all (TODO confirm exp_range's
    # end semantics).
    max_feat_partitions = max(1, feat_len // 16)

    def _bench_vanilla_spmm_x86(num_col_partitions, num_feat_partitions):
        """Build, run and time a single (col, feat) partitioning config."""
        vanilla_spmm_module = VanillaSpMMx86(adj_scipy_csr, num_col_partitions)
        SrcFeat = te.placeholder((num_cols, feat_len))
        input_placeholders = [SrcFeat]
        compute_args = {'num_feat_partitions': num_feat_partitions}
        schedule_args = {}
        vanilla_spmm_module.build(input_placeholders, compute_args, schedule_args)
        # Random dense input matching the placeholder's (constant) shape.
        src_feat_np = np.random.random(get_const_tuple(SrcFeat.shape)).astype('float32')
        src_feat_tvm = tvm.nd.array(src_feat_np, vanilla_spmm_module.ctx)
        input_tvm_ndarrays = [src_feat_tvm]
        tcost = vanilla_spmm_module.measure_average_time(input_tvm_ndarrays, num_runs)
        print("average time of {} runs: {} sec".format(num_runs, tcost))

    for num_col_partitions in exp_range(1, 32, 2):
        for num_feat_partitions in exp_range(1, max_feat_partitions, 2):
            print()
            print("num_col_partitions:", num_col_partitions)
            print("num_feat_partitions:", num_feat_partitions)
            _bench_vanilla_spmm_x86(num_col_partitions, num_feat_partitions)