lib/build-sets.nix (137 lines of code) (raw):

# Build one package set per Torch build configuration that targets `system`.
#
# Arguments:
#   nixpkgs       - nixpkgs source; must expose `lib` and be importable.
#   system        - system passed to every nixpkgs instantiation and used to
#                   select the matching build configurations.
#   hf-nix        - nixpkgs overlay, applied before the local `../overlay.nix`.
#   torchVersions - list of Torch version attribute sets. Entries may carry
#                   `cudaVersion` or `rocmVersion`; the remaining attributes
#                   are those accepted by `pkgsForVersions` below.
#
# Returns a list of `{ buildConfig, pkgs, torch, upstreamVariant }` attribute
# sets, one for each build configuration whose `system` equals `system`.
{
  nixpkgs,
  system,
  hf-nix,
  torchVersions,
}:

let
  inherit (nixpkgs) lib;

  overlay = import ../overlay.nix;

  inherit (import ./torch-version-utils.nix { inherit lib; })
    flattenSystems
    isCuda
    isMetal
    isRocm
    ;

  # Overlays shared by every nixpkgs instantiation in this file.
  baseOverlays = [
    hf-nix
    overlay
  ];

  # All build configurations supported by Torch for `targetSystem`.
  # NOTE(review): `flattenSystems` is defined in torch-version-utils.nix;
  # its result is expected to be a list of sets with a `system` attribute.
  buildConfigs =
    targetSystem:
    builtins.filter (version: version.system == targetSystem) (flattenSystems torchVersions);

  # Distinct CUDA versions referenced by any Torch version.
  cudaVersions = lib.unique (
    builtins.map (torchVersion: torchVersion.cudaVersion) (
      builtins.filter (torchVersion: torchVersion ? cudaVersion) torchVersions
    )
  );

  # Distinct ROCm versions referenced by any Torch version.
  rocmVersions = lib.unique (
    builtins.map (torchVersion: torchVersion.rocmVersion) (
      builtins.filter (torchVersion: torchVersion ? rocmVersion) torchVersions
    )
  );

  # Turn a version into an attribute-name suffix: keep the first two
  # components and replace dots by underscores, e.g. "12.4" -> "12_4".
  flattenVersion = version: lib.replaceStrings [ "." ] [ "_" ] (lib.versions.pad 2 version);

  # An overlay that overrides CUDA to the given version.
  overlayForCudaVersion = cudaVersion: _final: prev: {
    cudaPackages = prev."cudaPackages_${flattenVersion cudaVersion}";
  };

  # An overlay that overrides ROCm to the given version.
  overlayForRocmVersion = rocmVersion: _final: prev: {
    rocmPackages = prev."rocmPackages_${flattenVersion rocmVersion}";
  };

  # Construct the nixpkgs package set for the given versions.
  #
  # The argument pattern has no ellipsis, so `metal` and `system` are listed
  # even though they are only consumed through `buildConfig`: a build
  # configuration carrying any other attribute is rejected.
  pkgsForVersions =
    buildConfig@{
      cudaVersion ? null,
      metal ? false,
      rocmVersion ? null,
      torchVersion,
      cxx11Abi,
      system,
      upstreamVariant ? false,
    }:
    let
      # Pick the package set matching the compute framework of this config.
      pkgs =
        if isCuda buildConfig then
          pkgsByCudaVer.${cudaVersion}
        else if isRocm buildConfig then
          pkgsByRocmVer.${rocmVersion}
        else if isMetal buildConfig then
          pkgsForMetal
        else
          throw "No compute framework set in Torch version";

      # Versioned Torch attribute (e.g. `torch_2_6`) with the requested C++ ABI.
      torch = pkgs.python3.pkgs."torch_${flattenVersion torchVersion}".override {
        inherit cxx11Abi;
      };
    in
    {
      inherit
        buildConfig
        pkgs
        torch
        upstreamVariant
        ;
    };

  # nixpkgs for Metal builds: no extra config, only the shared overlays.
  pkgsForMetal = import nixpkgs {
    inherit system;
    overlays = baseOverlays;
  };

  # Instantiate nixpkgs for the given CUDA versions. Returns
  # an attribute set like `{ "12.4" = <nixpkgs with 12.4>; ... }`.
  pkgsForCudaVersions =
    cudaVersions:
    lib.genAttrs cudaVersions (
      cudaVersion:
      import nixpkgs {
        inherit system;
        config = {
          allowUnfree = true;
          cudaSupport = true;
        };
        overlays = baseOverlays ++ [ (overlayForCudaVersion cudaVersion) ];
      }
    );

  pkgsByCudaVer = pkgsForCudaVersions cudaVersions;

  # Instantiate nixpkgs for the given ROCm versions. Returns
  # an attribute set keyed by ROCm version string.
  pkgsForRocmVersions =
    rocmVersions:
    lib.genAttrs rocmVersions (
      rocmVersion:
      import nixpkgs {
        inherit system;
        config = {
          allowUnfree = true;
          rocmSupport = true;
        };
        overlays = baseOverlays ++ [ (overlayForRocmVersion rocmVersion) ];
      }
    );

  pkgsByRocmVer = pkgsForRocmVersions rocmVersions;

in
map pkgsForVersions (buildConfigs system)