fn main()

in candle-metal-kernels/tmp/unary.rs [9:105]


fn main() {
    let device = Device::system_default().unwrap();
    let kernels = Kernels::new();

    let f32_1k = (0..1000).map(|_| rand::random::<f32>()).collect::<Vec<_>>();
    let f32_10k = (0..10000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();
    let f32_100k = (0..100000)
        .map(|_| rand::random::<f32>())
        .collect::<Vec<_>>();

    let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::<Vec<_>>();
    let f16_1k = f16_map(&f32_1k);
    let f16_10k = f16_map(&f32_10k);
    let f16_100k = f16_map(&f32_100k);

    let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::<Vec<_>>();
    let bf16_1k = bf16_map(&f32_1k);
    let bf16_10k = bf16_map(&f32_10k);
    let bf16_100k = bf16_map(&f32_100k);

    let f32_ckernels = [
        unary::contiguous::sin::FLOAT,
        unary::contiguous::cos::FLOAT,
        unary::contiguous::exp::FLOAT,
        unary::contiguous::sqr::FLOAT,
        unary::contiguous::sqrt::FLOAT,
        unary::contiguous::neg::FLOAT,
        unary::contiguous::copy::FLOAT,
    ];
    let f32_skernels = [
        unary::strided::sin::FLOAT,
        unary::strided::cos::FLOAT,
        unary::strided::exp::FLOAT,
        unary::strided::sqr::FLOAT,
        unary::strided::sqrt::FLOAT,
        unary::strided::neg::FLOAT,
        unary::strided::copy::FLOAT,
    ];
    let f16_ckernels = [
        unary::contiguous::sin::HALF,
        unary::contiguous::cos::HALF,
        unary::contiguous::exp::HALF,
        unary::contiguous::sqr::HALF,
        unary::contiguous::sqrt::HALF,
        unary::contiguous::neg::HALF,
        unary::contiguous::copy::HALF,
    ];
    let f16_skernels = [
        unary::strided::sin::HALF,
        unary::strided::cos::HALF,
        unary::strided::exp::HALF,
        unary::strided::sqr::HALF,
        unary::strided::sqrt::HALF,
        unary::strided::neg::HALF,
        unary::strided::copy::HALF,
    ];
    let bf16_ckernels = [
        unary::contiguous::sin::BFLOAT,
        unary::contiguous::cos::BFLOAT,
        unary::contiguous::exp::BFLOAT,
        unary::contiguous::sqr::BFLOAT,
        unary::contiguous::sqrt::BFLOAT,
        unary::contiguous::neg::BFLOAT,
        unary::contiguous::copy::BFLOAT,
    ];
    let bf16_skernels = [
        unary::strided::sin::BFLOAT,
        unary::strided::cos::BFLOAT,
        unary::strided::exp::BFLOAT,
        unary::strided::sqr::BFLOAT,
        unary::strided::sqrt::BFLOAT,
        unary::strided::neg::BFLOAT,
        unary::strided::copy::BFLOAT,
    ];

    println!(
        "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}",
        "dtype", "kernel", "size", "runs", "total time", "avg time"
    );

    // f32
    run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels);
    run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels);
    run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels);

    // f16
    run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels);
    run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels);
    run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels);

    // bf16
    run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels);
    run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels);
    run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels);
}