csrc/custom_ops/custom_ops.h (17 lines of code) (raw):
#pragma once
#include <optional>
#include <vector>
#include <torch/all.h>
#include <torch/library.h>
/// @brief Prototype of the bulk "reshape and cache (flash)" custom op.
///
/// Only the declaration is visible in this header; no definition appears
/// here. The parameter notes below that go beyond the literal types are
/// inferred from names alone and are marked as assumptions to confirm
/// against the implementation.
///
/// @param keys           Key tensor, passed by non-const reference.
/// @param values         Value tensor, passed by non-const reference.
/// @param key_caches     List of key-cache tensors, passed by const reference
///                       (presumably one entry per layer -- TODO confirm).
/// @param value_caches   List of value-cache tensors, passed by const
///                       reference (presumably parallel to key_caches --
///                       TODO confirm).
/// @param slot_mapping   Tensor passed by non-const reference (presumably maps
///                       each token to a cache slot -- TODO confirm).
/// @param kv_cache_dtype String naming the KV-cache data type; the accepted
///                       values are not visible from this header.
/// @param k_scales       List of tensors (presumably per-cache key scaling
///                       factors -- TODO confirm).
/// @param v_scales       List of tensors (presumably per-cache value scaling
///                       factors -- TODO confirm).
/// @param num_heads      64-bit integer (presumably the number of attention
///                       heads -- TODO confirm).
/// @param head_size      64-bit integer (presumably the per-head dimension --
///                       TODO confirm).
void reshape_and_cache_flash_bulk(
    torch::Tensor& keys,
    torch::Tensor& values,
    const std::vector<torch::Tensor>& key_caches,
    const std::vector<torch::Tensor>& value_caches,
    torch::Tensor& slot_mapping,
    const std::string& kv_cache_dtype,
    const std::vector<torch::Tensor>& k_scales,
    const std::vector<torch::Tensor>& v_scales,
    int64_t num_heads,
    int64_t head_size);