csrc/custom_ops/torch_bindings.cpp (17 lines of code) (raw):

#include "custom_ops.h"

#include <torch/library.h>

// Registers the operators of the `arctic_inference` library with the PyTorch
// dispatcher.
//
// reshape_and_cache_flash_bulk
//   - `keys`, `values`, `slot_mapping`: plain tensor inputs.
//   - `key_caches`, `value_caches`: tensor lists annotated `(c!)` / `(d!)`,
//     i.e. the schema declares them as mutated in place.
//   - `k_scales`, `v_scales`: tensor lists carrying alias annotations
//     (`(e)` / `(f)`) without the mutation marker.
//   - `kv_cache_dtype`: string; `num_heads`, `head_size`: integers.
//   - Returns nothing (`-> ()`).
//
// Only a CUDA kernel is bound here; no other dispatch key is registered in
// this block.
TORCH_LIBRARY(arctic_inference, m) {
  // Operator name used when attaching the implementation below.
  static constexpr char kOpName[] = "reshape_and_cache_flash_bulk";

  // Full function schema. NOTE(review): must stay in sync with the C++
  // signature of reshape_and_cache_flash_bulk (presumably declared in
  // custom_ops.h -- confirm).
  static constexpr char kOpSchema[] =
      "reshape_and_cache_flash_bulk(Tensor keys,"
      " Tensor values,"
      " Tensor(c!)[] key_caches,"
      " Tensor(d!)[] value_caches,"
      " Tensor slot_mapping,"
      " str kv_cache_dtype,"
      " Tensor(e)[] k_scales,"
      " Tensor(f)[] v_scales,"
      " int num_heads,"
      " int head_size) -> ()";

  m.def(kOpSchema);
  m.impl(kOpName, torch::kCUDA, &reshape_and_cache_flash_bulk);
}