build.rs (45 lines of code) (raw):
// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel.
// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result};
use std::path::PathBuf;
const KERNEL_FILES: [&str; 1] = ["kernels/rotary.cu"];
fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs");
for kernel_file in KERNEL_FILES.iter() {
println!("cargo:rerun-if-changed={kernel_file}");
}
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_ROTARY_BUILD_DIR") {
Err(_) =>
{
#[allow(clippy::redundant_clone)]
out_dir.clone()
}
Ok(build_dir) => {
let path = PathBuf::from(build_dir);
path.canonicalize().expect(&format!(
"Directory doesn't exists: {} (the current directory is {})",
&path.display(),
std::env::current_dir()?.display()
))
}
};
let kernels : Vec<_>= KERNEL_FILES.iter().collect();
let builder = bindgen_cuda::Builder::default().kernel_paths(kernels).out_dir(build_dir.clone())
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--ptxas-options=-v")
.arg("--verbose");
let out_file = build_dir.join("librotary.a");
builder.build_lib(out_file);
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=rotary");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
Ok(())
}