build2cmake/src/main.rs
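//! build2cmake: generate CMake files for Torch extension builds from a `build.toml`.
//!
//! Example invocations (a sketch only; `update-build` matches the hint printed for
//! deprecated V1 files below, while the other subcommand spellings and the
//! `--backend` value are assumptions based on clap's default kebab-case/lowercase
//! derive naming):
//!
//! ```text
//! build2cmake generate-torch build.toml              # backend and target dir inferred
//! build2cmake generate-torch build.toml out/ --backend cuda --force
//! build2cmake update-build build.toml                # rewrite a V1 build.toml in the current format
//! build2cmake validate build.toml
//! build2cmake clean build.toml --dry-run
//! ```
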
use std::{
fs::{self, File},
io::{BufWriter, Read, Write},
path::{Path, PathBuf},
};

use clap::{Parser, Subcommand};
use eyre::{bail, ensure, Context, Result};
use minijinja::Environment;

mod torch;
use torch::{write_torch_ext_cuda, write_torch_ext_metal, write_torch_ext_universal};

mod config;
use config::{Backend, Build, BuildCompat};

mod fileset;
use fileset::FileSet;

mod version;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Cli {
#[command(subcommand)]
command: Commands,
}

#[derive(Debug, Subcommand)]
enum Commands {
/// Generate CMake files for Torch extension builds.
GenerateTorch {
#[arg(name = "BUILD_TOML")]
build_toml: PathBuf,
/// The directory to write the generated files to
/// (directory of `BUILD_TOML` when absent).
#[arg(name = "TARGET_DIR")]
target_dir: Option<PathBuf>,
/// Force-overwrite existing files.
#[arg(short, long)]
force: bool,
        /// Optional unique identifier (e.g. a Git SHA) that is appended to the
        /// kernel name to avoid name collisions.
#[arg(long)]
ops_id: Option<String>,
#[arg(long)]
backend: Option<Backend>,
},
/// Update a `build.toml` to the current format.
UpdateBuild {
#[arg(name = "BUILD_TOML")]
build_toml: PathBuf,
},
    /// Validate a `build.toml` file.
Validate {
#[arg(name = "BUILD_TOML")]
build_toml: PathBuf,
},
/// Clean generated artifacts.
Clean {
#[arg(name = "BUILD_TOML")]
build_toml: PathBuf,
/// The directory to clean from (directory of `BUILD_TOML` when absent).
#[arg(name = "TARGET_DIR")]
target_dir: Option<PathBuf>,
/// Show what would be deleted without actually deleting.
#[arg(short, long)]
dry_run: bool,
/// Force deletion without confirmation.
#[arg(short, long)]
force: bool,
        /// Optional unique identifier (e.g. a Git SHA) that is appended to the
        /// kernel name to avoid name collisions.
#[arg(long)]
ops_id: Option<String>,
},
}

fn main() -> Result<()> {
let args = Cli::parse();
match args.command {
Commands::GenerateTorch {
backend,
build_toml,
force,
target_dir,
ops_id,
} => generate_torch(backend, build_toml, target_dir, force, ops_id),
Commands::UpdateBuild { build_toml } => update_build(build_toml),
Commands::Validate { build_toml } => {
parse_and_validate(build_toml)?;
Ok(())
}
Commands::Clean {
build_toml,
target_dir,
dry_run,
force,
ops_id,
} => clean(build_toml, target_dir, dry_run, force, ops_id),
}
}
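
/// Generate CMake and torch-ext files for the build described by `build_toml`,
/// selecting (or inferring) a single backend, and write them into `target_dir`.
/// Existing files are only overwritten when `force` is set.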
fn generate_torch(
backend: Option<Backend>,
build_toml: PathBuf,
target_dir: Option<PathBuf>,
force: bool,
ops_id: Option<String>,
) -> Result<()> {
let target_dir = check_or_infer_target_dir(&build_toml, target_dir)?;
let build_compat = parse_and_validate(build_toml)?;
if matches!(build_compat, BuildCompat::V1(_)) {
        eprintln!(
            "build.toml is in the deprecated V1 format; use `build2cmake update-build` to update."
        )
}
let build: Build = build_compat
.try_into()
.context("Cannot update build configuration")?;
let mut env = Environment::new();
env.set_trim_blocks(true);
minijinja_embed::load_templates!(&mut env);
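
    // Pick the backend to generate for: universal builds are emitted directly
    // (an explicit backend is rejected); otherwise use the requested backend,
    // or infer it when the build defines kernels for exactly one backend.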
let backend = match (backend, build.general.universal) {
(None, true) => {
let file_set = write_torch_ext_universal(&env, &build, target_dir.clone(), ops_id)?;
file_set.write(&target_dir, force)?;
return Ok(());
}
(Some(backend), true) => bail!("Universal kernel, cannot generate for backend {}", backend),
(Some(backend), false) => {
if !build.has_kernel_with_backend(&backend) {
bail!("No kernels found for backend {}", backend);
}
backend
}
(None, false) => {
let mut kernel_backends = build.backends();
let backend = if let Some(backend) = kernel_backends.pop_first() {
backend
} else {
bail!("No kernels found in build.toml");
};
if !kernel_backends.is_empty() {
let kernel_backends: Vec<_> = build
.backends()
.into_iter()
.map(|backend| backend.to_string())
.collect();
bail!(
"Multiple supported backends found in build.toml: {}. Please specify one with --backend.",
kernel_backends.join(", ")
);
}
backend
}
};
let file_set = match backend {
Backend::Cuda | Backend::Rocm => {
write_torch_ext_cuda(&env, backend, &build, target_dir.clone(), ops_id)?
}
Backend::Metal => write_torch_ext_metal(&env, &build, target_dir.clone(), ops_id)?,
};
file_set.write(&target_dir, force)?;
Ok(())
}
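
/// Rewrite `build.toml` in place in the current format. Files that are already
/// in the current (V2) format are left untouched.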
fn update_build(build_toml: PathBuf) -> Result<()> {
let build_compat: BuildCompat = parse_and_validate(&build_toml)?;
if matches!(build_compat, BuildCompat::V2(_)) {
return Ok(());
}
let build: Build = build_compat
.try_into()
.context("Cannot update build configuration")?;
let pretty_toml = toml::to_string_pretty(&build)?;
let mut writer =
BufWriter::new(File::create(&build_toml).wrap_err_with(|| {
format!("Cannot open {} for writing", build_toml.to_string_lossy())
})?);
writer
.write_all(pretty_toml.as_bytes())
.wrap_err_with(|| format!("Cannot write to {}", build_toml.to_string_lossy()))?;
Ok(())
}
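
/// Use `target_dir` when given (it must be an existing directory); otherwise
/// fall back to the parent directory of `build_toml`.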
fn check_or_infer_target_dir(
build_toml: impl AsRef<Path>,
target_dir: Option<PathBuf>,
) -> Result<PathBuf> {
let build_toml = build_toml.as_ref();
match target_dir {
Some(target_dir) => {
ensure!(
target_dir.is_dir(),
"`{}` is not a directory",
target_dir.to_string_lossy()
);
Ok(target_dir)
}
None => {
let absolute = std::path::absolute(build_toml)?;
match absolute.parent() {
Some(parent) => Ok(parent.to_owned()),
None => bail!(
"Cannot get parent path of `{}`",
build_toml.to_string_lossy()
),
}
}
}
}
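
/// Read `build_toml` and parse it as a `BuildCompat`, which covers both the
/// deprecated V1 format and the current format.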
fn parse_and_validate(build_toml: impl AsRef<Path>) -> Result<BuildCompat> {
let build_toml = build_toml.as_ref();
let mut toml_data = String::new();
File::open(build_toml)
.wrap_err_with(|| format!("Cannot open {} for reading", build_toml.to_string_lossy()))?
.read_to_string(&mut toml_data)
.wrap_err_with(|| format!("Cannot read from {}", build_toml.to_string_lossy()))?;
let build_compat: BuildCompat = toml::from_str(&toml_data)
.wrap_err_with(|| format!("Cannot parse TOML in {}", build_toml.to_string_lossy()))?;
Ok(build_compat)
}
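
/// Remove previously generated artifacts for `build_toml`. With `dry_run` the
/// matching files are only listed; otherwise the user is asked to confirm the
/// deletion unless `force` is set. Generated directories that end up empty are
/// removed as well.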
fn clean(
build_toml: PathBuf,
target_dir: Option<PathBuf>,
dry_run: bool,
force: bool,
ops_id: Option<String>,
) -> Result<()> {
let target_dir = check_or_infer_target_dir(&build_toml, target_dir)?;
let build_compat = parse_and_validate(build_toml)?;
if matches!(build_compat, BuildCompat::V1(_)) {
        eprintln!(
            "build.toml is in the deprecated V1 format; use `build2cmake update-build` to update."
        )
}
let build: Build = build_compat
.try_into()
.context("Cannot update build configuration")?;
let mut env = Environment::new();
env.set_trim_blocks(true);
minijinja_embed::load_templates!(&mut env);
let generated_files = get_generated_files(&env, &build, target_dir.clone(), ops_id)?;
if generated_files.is_empty() {
eprintln!("No generated artifacts found to clean.");
return Ok(());
}
if dry_run {
println!("Files that would be deleted:");
for file in &generated_files {
if file.exists() {
println!(" {}", file.to_string_lossy());
}
}
return Ok(());
}
let existing_files: Vec<_> = generated_files.iter().filter(|f| f.exists()).collect();
if existing_files.is_empty() {
eprintln!("No generated artifacts found to clean.");
return Ok(());
}
if !force {
println!("Files to be deleted:");
for file in &existing_files {
println!(" {}", file.to_string_lossy());
}
print!("Continue? [y/N] ");
std::io::stdout().flush()?;
let mut response = String::new();
std::io::stdin().read_line(&mut response)?;
let response = response.trim().to_lowercase();
if response != "y" && response != "yes" {
eprintln!("Aborted.");
return Ok(());
}
}
let mut deleted_count = 0;
let mut errors = Vec::new();
for file in existing_files {
match fs::remove_file(file) {
Ok(_) => {
deleted_count += 1;
println!("Deleted: {}", file.to_string_lossy());
}
Err(e) => {
errors.push(format!(
"Failed to delete {}: {}",
file.to_string_lossy(),
e
));
}
}
}
// Clean up empty directories
let dirs_to_check = [
target_dir.join("cmake"),
target_dir.join("torch-ext").join(&build.general.name),
target_dir.join("torch-ext"),
];
for dir in dirs_to_check {
if dir.exists() && is_empty_dir(&dir)? {
match fs::remove_dir(&dir) {
Ok(_) => println!("Removed empty directory: {}", dir.to_string_lossy()),
                Err(e) => bail!("Failed to remove directory `{}`: {e:?}", dir.display()),
}
}
}
if !errors.is_empty() {
for error in errors {
eprintln!("Error: {error}");
}
bail!("Some files could not be deleted");
}
println!("Cleaned {deleted_count} generated artifacts.");
Ok(())
}
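
/// Build the `FileSet`s for every backend of the build (and for the universal
/// variant, if enabled) and return the paths they contain, so `clean` knows
/// which artifacts were generated.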
fn get_generated_files(
env: &Environment,
build: &Build,
target_dir: PathBuf,
ops_id: Option<String>,
) -> Result<Vec<PathBuf>> {
let mut all_set = FileSet::new();
for backend in build.backends() {
let set = match backend {
Backend::Cuda | Backend::Rocm => {
write_torch_ext_cuda(env, backend, build, target_dir.clone(), ops_id.clone())?
}
Backend::Metal => {
write_torch_ext_metal(env, build, target_dir.clone(), ops_id.clone())?
}
};
all_set.extend(set);
}
if build.general.universal {
let set = write_torch_ext_universal(env, build, target_dir, ops_id)?;
all_set.extend(set);
}
Ok(all_set.into_names())
}
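
/// Returns `Ok(true)` when `dir` is a directory with no entries.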
fn is_empty_dir(dir: &Path) -> Result<bool> {
if !dir.is_dir() {
return Ok(false);
}
let mut entries = fs::read_dir(dir)?;
Ok(entries.next().is_none())
}