fn run_binary_bench()

in candle-metal-kernels/tmp/binary.rs [89:182]


/// Benchmarks binary Metal kernels over `v` (used as both operands) and prints
/// one table row per kernel.
///
/// For each kernel in `contiguous` and then in `strided`, the kernel is encoded
/// `iterations` times into a single command buffer, which is then committed and
/// waited on. The reported time therefore covers encoding plus GPU execution of
/// that whole command buffer.
///
/// Printed columns: element type, kernel name, element count, iteration count,
/// total elapsed time, and average time per iteration.
///
/// # Panics
///
/// Panics if any kernel call returns an error.
fn run_binary_bench<T: Clone>(
    device: &Device,
    kernels: &Kernels,
    v: &[T],
    contiguous: [binary::contiguous::Kernel; 4],
    strided: [binary::strided::Kernel; 4],
) {
    let command_queue = device.new_command_queue();
    let options = MTLResourceOptions::StorageModeManaged;

    let iterations = 1000;
    let byte_len = core::mem::size_of_val(v) as u64;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        byte_len,
        options,
    );
    // `output` has exactly as many elements as `v`; no kernel below may write more.
    let mut output = device.new_buffer(byte_len, options);

    // `split` always yields at least one item, so this `unwrap` cannot fail.
    let type_label = type_name::<T>().split("::").last().unwrap();
    let print_row = |kernel: &str, elements: usize, total_time: std::time::Duration| {
        println!(
            "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}",
            type_label,
            kernel,
            elements,
            iterations,
            total_time,
            total_time / iterations
        );
    };

    // Contiguous
    for kernel_name in contiguous {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_binary_contiguous(
                    device,
                    command_buffer,
                    kernels,
                    kernel_name,
                    v.len(),
                    &input,
                    &input,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        // `to_string()` first so the width specifier pads the name even if the
        // kernel's `Display` impl does not honor formatter padding.
        print_row(&kernel_name.to_string(), v.len(), total_time);
    }

    // Strided: view the input as a 2 x (len / 2) matrix whose rows start `half`
    // elements apart. Deriving the layout from `v.len()` keeps the largest index
    // touched (2 * half - 1) inside `input`, and the number of outputs
    // (2 * half) within `output`.
    let half = v.len() / 2;
    let shape = vec![2, half];
    let strides = vec![half, 1];
    let offset = 0;
    let strided_elements: usize = shape.iter().product();
    for kernel_name in strided {
        let total_time = autoreleasepool(|| {
            let command_buffer = command_queue.new_command_buffer();
            let start = Instant::now();
            for _ in 0..iterations {
                call_binary_strided(
                    device,
                    command_buffer,
                    kernels,
                    kernel_name,
                    &shape,
                    &input,
                    &strides,
                    offset,
                    &input,
                    &strides,
                    offset,
                    &mut output,
                )
                .unwrap();
            }
            command_buffer.commit();
            command_buffer.wait_until_completed();

            start.elapsed()
        });
        print_row(&kernel_name.to_string(), strided_elements, total_time);
    }
}