flake.nix (193 lines of code) (raw):

{
  description = "Kernel builder";

  inputs = {
    flake-utils.url = "github:numtide/flake-utils";
    nixpkgs.follows = "hf-nix/nixpkgs";
    flake-compat.url = "github:edolstra/flake-compat";
    hf-nix.url = "github:huggingface/hf-nix";
  };

  outputs =
    {
      self,
      flake-compat,
      flake-utils,
      hf-nix,
      nixpkgs,
    }:
    let
      systems = with flake-utils.lib.system; [
        aarch64-darwin
        aarch64-linux
        x86_64-linux
      ];

      # Default Torch versions, read from ./versions.nix.
      torchVersions' = import ./versions.nix;

      # Create an attrset { "<system>" = [ <buildset> ...]; ... }.
      mkBuildSetsPerSystem =
        torchVersions:
        nixpkgs.lib.genAttrs systems (
          system:
          import ./lib/build-sets.nix {
            inherit nixpkgs system torchVersions;
            hf-nix = hf-nix.overlays.default;
          }
        );

      defaultBuildSetsPerSystem = mkBuildSetsPerSystem torchVersions';

      # Map { "<system>" = <build sets>; ... } to
      # { "<system>" = <attrset returned by ./lib/build.nix>; ... }.
      mkBuildPerSystem =
        buildSetPerSystem:
        builtins.mapAttrs (
          _system: buildSets:
          import ./lib/build.nix {
            inherit (nixpkgs) lib;
            inherit buildSets;
          }
        ) buildSetPerSystem;

      defaultBuildPerSystem = mkBuildPerSystem defaultBuildSetsPerSystem;

      # The lib output consists of two parts:
      #
      # - Per-system build functions.
      # - `genFlakeOutputs`, which can be used by downstream flakes to make
      #   standardized outputs (for all supported systems).
      lib = {
        allBuildVariantsJSON =
          let
            buildVariants =
              (import ./lib/build-variants.nix {
                inherit (nixpkgs) lib;
                torchVersions = torchVersions';
              }).buildVariants;
          in
          builtins.toJSON buildVariants;

        genFlakeOutputs =
          {
            path,
            rev,
            pythonCheckInputs ? pkgs: [ ],
            pythonNativeCheckInputs ? pkgs: [ ],
            torchVersions ? torchVersions',
          }:
          let
            buildSetPerSystem' = mkBuildSetsPerSystem torchVersions;
            buildPerSystem = mkBuildPerSystem buildSetPerSystem';
          in
          flake-utils.lib.eachSystem systems (
            system:
            let
              build = buildPerSystem.${system};
              # `-` in the revision is replaced by `_`.
              # NOTE(review): presumably because `rev` ends up in identifiers
              # (e.g. build2cmake's `--ops-id`) — confirm.
              revUnderscored = builtins.replaceStrings [ "-" ] [ "_" ] rev;
              pkgs = nixpkgs.legacyPackages.${system};
              # Build variant backing the `default` and `test` dev shells.
              shellTorch =
                if system == "aarch64-darwin" then
                  "torch27-metal-${system}"
                else
                  "torch27-cxx11-cu126-${system}";
            in
            {
              devShells = rec {
                default = devShells.${shellTorch};
                test = testShells.${shellTorch};
                devShells = build.torchDevShells {
                  inherit path pythonCheckInputs pythonNativeCheckInputs;
                  rev = revUnderscored;
                };
                testShells = build.torchExtensionShells {
                  inherit path pythonCheckInputs pythonNativeCheckInputs;
                  rev = revUnderscored;
                };
              };
              packages = rec {
                default = bundle;
                bundle = build.buildTorchExtensionBundle {
                  inherit path;
                  rev = revUnderscored;
                };
                redistributable = build.buildDistTorchExtensions {
                  inherit path;
                  buildSets = buildSetPerSystem'.${system};
                  rev = revUnderscored;
                };
                # Source tree plus CMake files generated by build2cmake.
                buildTree =
                  let
                    build2cmake = self.packages.${system}.build2cmake;
                    src = build.mkSourceSet path;
                  in
                  pkgs.runCommand "torch-extension-build-tree"
                    {
                      nativeBuildInputs = [ build2cmake ];
                      inherit src;
                      meta = {
                        description = "Build tree for torch extension with source files and CMake configuration";
                      };
                    }
                    ''
                      # Copy sources
                      install -dm755 $out/src
                      cp -r $src/. $out/src/

                      # Generate cmake files
                      build2cmake generate-torch --ops-id "${revUnderscored}" $src/build.toml $out --force
                    '';
              };
            }
          );
      }
      // defaultBuildPerSystem;
    in
    flake-utils.lib.eachSystem systems (
      system:
      let
        # Plain nixpkgs that we use to access utility functions. The flake's
        # pre-instantiated package set is used instead of importing nixpkgs
        # a second time for the same system.
        pkgs = nixpkgs.legacyPackages.${system};
        # NOTE: within this per-system scope this shadows the flake-level
        # `lib` (build functions). The trailing `// { inherit lib; }` is
        # outside this scope and therefore still exports the flake-level one.
        inherit (nixpkgs) lib;
        buildVersion = import ./lib/build-version.nix;
        buildSets = defaultBuildSetsPerSystem.${system};
      in
      {
        formatter = pkgs.nixfmt-tree;

        packages = rec {
          build2cmake = pkgs.callPackage ./pkgs/build2cmake { };

          # Runs `build2cmake update-build` on the given file (default:
          # `build.toml`). The argument is quoted so a path containing
          # whitespace reaches build2cmake as a single argument.
          update-build = pkgs.writeShellScriptBin "update-build" ''
            ${build2cmake}/bin/build2cmake update-build "''${1:-build.toml}"
          '';

          # This package set is exposed so that we can prebuild the Torch versions.
          torch = builtins.listToAttrs (
            map (buildSet: {
              name = buildVersion buildSet;
              value = buildSet.torch;
            }) buildSets
          );

          # Dependencies that should be cached.
          forCache =
            let
              filterDist = lib.filter (output: output != "dist");

              # Get all `torch` outputs except for `dist`. Not all outputs
              # are dependencies of `out`, but we'll need the `cxxdev` and
              # `dev` outputs for kernel builds.
              torchOutputs = builtins.listToAttrs (
                lib.flatten (
                  # Map over build sets.
                  map (
                    buildSet:
                    # Map over all outputs of `torch` in a buildset.
                    map (output: {
                      name = "${buildVersion buildSet}-${output}";
                      value = buildSet.torch.${output};
                    }) (filterDist buildSet.torch.outputs)
                  ) buildSets
                )
              );

              # The `stdenvGlibc_2_27` of every build set's package set;
              # only added to the cache farm on Linux (see below).
              oldLinuxStdenvs = builtins.listToAttrs (
                map (buildSet: {
                  name = "stdenv-${buildVersion buildSet}";
                  value = buildSet.pkgs.stdenvGlibc_2_27;
                }) buildSets
              );
            in
            pkgs.linkFarm "packages-for-cache" (
              torchOutputs // lib.optionalAttrs pkgs.stdenv.isLinux oldLinuxStdenvs
            );
        };
      }
    )
    // {
      inherit lib;
    };
}