lib/build.nix (266 lines of code) (raw):
{
lib,
# List of build sets. Each build set is an attrset of the form
#
# { buildConfig = <torch build config>; pkgs = <nixpkgs>; torch = <torch drv>; upstreamVariant = <bool>; }
#
# The Torch derivation is used as-is, so e.g. the C++ ABI version must
# already be set.
buildSets,
}:
let
torchBuildVersion = import ./build-version.nix;
# Fallback list of CUDA capabilities, used when a kernel in build.toml
# does not set `cuda-capabilities` itself.
supportedCudaCapabilities = builtins.fromJSON (
  builtins.readFile ../build2cmake/src/cuda_supported_archs.json
);
inherit (import ./torch-version-utils.nix { inherit lib; }) isCuda isMetal isRocm;
in
rec {
# Turns the `depends` entries of a kernel build config into extra build
# dependencies (implemented in deps.nix).
resolveDeps =
  let
    depsResolver = import ./deps.nix;
  in
  depsResolver { inherit lib; };
# Parse the TOML file at `path` into an attrset.
readToml =
  path:
  let
    contents = builtins.readFile path;
  in
  builtins.fromTOML contents;
# Reject build configs written for an older build.toml schema and return
# the config unchanged otherwise. A config is considered outdated when it
# has `torch.universal` or any kernel has a `language` attribute.
validateBuildConfig =
  buildConfig:
  let
    torchSection = buildConfig.torch or { };
    kernelConfigs = builtins.attrValues (buildConfig.kernel or { });
    isOutdated =
      torchSection ? universal || builtins.any (kernelConfig: kernelConfig ? language) kernelConfigs;
  in
  assert lib.assertMsg (!isOutdated) ''
    build.toml seems to be of an older version, update it with:
    build2cmake update-build build.toml'';
  buildConfig;
# Compute which backends the kernels of a build config target. The
# result always contains `cuda`, `metal` and `rocm` (false by default);
# the `backend` value of every kernel is set to true.
backends =
  buildConfig:
  let
    kernelConfigs = builtins.attrValues (buildConfig.kernel or { });
    usedBackends = builtins.listToAttrs (
      map (kernelConfig: lib.nameValuePair kernelConfig.backend true) kernelConfigs
    );
  in
  {
    cuda = false;
    metal = false;
    rocm = false;
  }
  // usedBackends;
# Read and validate the build.toml in the kernel directory `path`.
readBuildConfig =
  path:
  let
    buildToml = path + "/build.toml";
  in
  validateBuildConfig (readToml buildToml);
# Predicate `name: type: bool` that keeps every directory and every
# entry whose name ends with one of the given suffixes.
srcFilter =
  suffixes: name: type:
  let
    hasWantedSuffix = builtins.any (suffix: lib.hasSuffix suffix name) suffixes;
  in
  type == "directory" || hasWantedSuffix;
# Function from a kernel directory to its source fileset (implemented in
# source-set.nix).
mkSourceSet =
  let
    sourceSetFor = import ./source-set.nix;
  in
  sourceSetFor { inherit lib; };
# Keep only the build sets that apply to a kernel with the given build
# config:
#
# - the build set's backend (CUDA, ROCm or Metal) must be used by one of
#   the kernels, or the kernel must set `general.universal`;
# - CUDA build sets must have a CUDA version within the inclusive range
#   [general.cuda-minver, general.cuda-maxver] (defaults 11.8 and 99.9).
applicableBuildSets =
  buildConfig: buildSets:
  let
    usedBackends = backends buildConfig;
    cudaMin = buildConfig.general.cuda-minver or "11.8";
    cudaMax = buildConfig.general.cuda-maxver or "99.9";
    inCudaRange =
      version:
      builtins.compareVersions version cudaMin >= 0 && builtins.compareVersions version cudaMax <= 0;
    isApplicable =
      buildSet:
      let
        torchConfig = buildSet.buildConfig;
        hasBackend =
          (isCuda torchConfig && usedBackends.cuda)
          || (isRocm torchConfig && usedBackends.rocm)
          || (isMetal torchConfig && usedBackends.metal)
          || (buildConfig.general.universal or false);
        cudaOk = !(isCuda torchConfig) || inCudaRange buildSet.pkgs.cudaPackages.cudaMajorMinorVersion;
      in
      hasBackend && cudaOk;
  in
  builtins.filter isApplicable buildSets;
# Build a single Torch extension.
#
# The first argument is a build set (see the `buildSets` argument of this
# file); only its `pkgs` and `torch` are used here. The kernel's own build
# config is always read from `path`/build.toml.
#
# Second argument:
#   path: kernel source directory containing build.toml.
#   rev: source revision, passed through to the extension derivation.
#   stripRPath: passed through to ./torch-extension.
#   oldLinuxCompat: use the glibc 2.27 stdenv (except on Darwin) and
#     enable the ABI check.
buildTorchExtension =
  {
    pkgs,
    torch,
    # The build set also carries `buildConfig` and `upstreamVariant`,
    # which are not needed to build a single extension.
    ...
  }:
  {
    path,
    rev,
    stripRPath ? false,
    oldLinuxCompat ? false,
  }:
  let
    buildConfig = readBuildConfig path;
    kernels = buildConfig.kernel or { };
    # `universal` is optional in build.toml (applicableBuildSets also
    # defaults it to false).
    universal = buildConfig.general.universal or false;
    extraDeps = resolveDeps {
      inherit pkgs torch;
      deps = lib.unique (lib.flatten (lib.mapAttrsToList (_: kernel: kernel.depends or [ ]) kernels));
    };
    src = mkSourceSet path;
    # Set the number of nvcc threads to the largest number of CUDA
    # capabilities that any kernel builds for (at least 1).
    nvccThreads = lib.foldl' lib.max 1 (
      lib.mapAttrsToList (
        _: kernel: builtins.length (kernel.cuda-capabilities or supportedCudaCapabilities)
      ) kernels
    );
    stdenv =
      if pkgs.stdenv.hostPlatform.isDarwin then
        pkgs.stdenv
      else if oldLinuxCompat then
        pkgs.stdenvGlibc_2_27
      else
        pkgs.cudaPackages.backendStdenv;
  in
  if universal then
    # No torch extension sources? Treat it as a noarch package.
    pkgs.callPackage ./torch-extension-noarch {
      inherit src rev torch;
      extensionName = buildConfig.general.name;
    }
  else
    pkgs.callPackage ./torch-extension {
      inherit
        extraDeps
        nvccThreads
        src
        stdenv
        stripRPath
        torch
        rev
        ;
      extensionName = buildConfig.general.name;
      doAbiCheck = oldLinuxCompat;
    };
# Build the kernel at `path` for every applicable build set in this
# file's `buildSets`. The result is an attrset keyed by
# `torchBuildVersion` of each build set.
buildNixTorchExtensions =
  { path, rev }:
  let
    buildConfig = readBuildConfig path;
    toNamedExtension =
      buildSet:
      lib.nameValuePair (torchBuildVersion buildSet) (buildTorchExtension buildSet { inherit path rev; });
  in
  builtins.listToAttrs (map toNamedExtension (applicableBuildSets buildConfig buildSets));
# Build the kernel at `path` for every applicable build set in the given
# `buildSets`, with RPATH stripping and `oldLinuxCompat` enabled. The
# result is an attrset keyed by `torchBuildVersion` of each build set.
buildDistTorchExtensions =
  {
    buildSets,
    path,
    rev,
  }:
  let
    distExtension =
      buildSet:
      buildTorchExtension buildSet {
        inherit path rev;
        stripRPath = true;
        oldLinuxCompat = true;
      };
    applicable = applicableBuildSets (readBuildConfig path) buildSets;
  in
  builtins.listToAttrs (
    map (buildSet: {
      name = torchBuildVersion buildSet;
      value = distExtension buildSet;
    }) applicable
  );
# Join the distributable builds of the kernel for all upstream build sets
# (those with `upstreamVariant` set) into one "torch-ext-bundle" output.
# Each build is linked under its build version name; a universal kernel
# contributes a single "torch-universal" entry instead.
buildTorchExtensionBundle =
  { path, rev }:
  let
    # We just need any nixpkgs for the path join.
    pkgs = (builtins.head buildSets).pkgs;
    upstreamBuildSets = builtins.filter (buildSet: buildSet.upstreamVariant) buildSets;
    extensions = buildDistTorchExtensions {
      inherit path rev;
      buildSets = upstreamBuildSets;
    };
    buildConfig = readBuildConfig path;
    # `universal` is optional in build.toml.
    universal = buildConfig.general.universal or false;
    namePaths =
      if universal then
        # Noarch: just take the first extension. Fail with a clear message
        # instead of an out-of-bounds `head` when nothing was built.
        assert lib.assertMsg (
          extensions != { }
        ) "no applicable upstream build set for universal kernel at ${toString path}";
        { "torch-universal" = builtins.head (builtins.attrValues extensions); }
      else
        lib.mapAttrs (_: toString) extensions;
  in
  import ./join-paths {
    inherit pkgs namePaths;
    name = "torch-ext-bundle";
  };
# Get a development shell per applicable build set with the built
# extension appended to PYTHONPATH. Handy for running tests.
#
# `pythonCheckInputs` and `pythonNativeCheckInputs` are functions from a
# Python package set (`python3.pkgs`) to a list of packages.
torchExtensionShells =
  {
    path,
    rev,
    pythonCheckInputs,
    pythonNativeCheckInputs,
  }:
  let
    shellForBuildSet =
      buildSet:
      let
        pkgs = buildSet.pkgs;
        python3Packages = pkgs.python3.pkgs;
        rocmSupport = pkgs.config.rocmSupport or false;
        # Use the default stdenv on Darwin, as buildTorchExtension does, and
        # for ROCm; the CUDA backend stdenv otherwise.
        stdenv =
          if pkgs.stdenv.hostPlatform.isDarwin || rocmSupport then
            pkgs.stdenv
          else
            pkgs.cudaPackages.backendStdenv;
        mkShell = pkgs.mkShell.override { inherit stdenv; };
      in
      {
        name = torchBuildVersion buildSet;
        value = mkShell {
          nativeBuildInputs = pythonNativeCheckInputs python3Packages;
          buildInputs = [
            buildSet.torch
            python3Packages.pytest
          ]
          ++ (pythonCheckInputs python3Packages);
          shellHook = ''
            export PYTHONPATH=''${PYTHONPATH}:${buildTorchExtension buildSet { inherit path rev; }}
          '';
        };
      };
    filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
  in
  builtins.listToAttrs (map shellForBuildSet filteredBuildSets);
# Get a development shell per applicable build set that provides the build
# inputs of the extension (via `inputsFrom`) plus build2cmake,
# kernel-abi-check and pytest, for working on the kernel sources.
#
# `pythonCheckInputs` and `pythonNativeCheckInputs` are functions from a
# Python package set (`python3.pkgs`) to a list of packages.
torchDevShells =
  {
    path,
    rev,
    pythonCheckInputs,
    pythonNativeCheckInputs,
  }:
  let
    shellForBuildSet =
      buildSet:
      let
        pkgs = buildSet.pkgs;
        python3Packages = pkgs.python3.pkgs;
        rocmSupport = pkgs.config.rocmSupport or false;
        # Use the default stdenv on Darwin, as buildTorchExtension does, and
        # for ROCm; the CUDA backend stdenv otherwise.
        stdenv =
          if pkgs.stdenv.hostPlatform.isDarwin || rocmSupport then
            pkgs.stdenv
          else
            pkgs.cudaPackages.backendStdenv;
        mkShell = pkgs.mkShell.override { inherit stdenv; };
      in
      {
        name = torchBuildVersion buildSet;
        value = mkShell {
          nativeBuildInputs = [
            pkgs.build2cmake
            pkgs.kernel-abi-check
          ]
          ++ (pythonNativeCheckInputs python3Packages);
          buildInputs = [ python3Packages.pytest ] ++ (pythonCheckInputs python3Packages);
          inputsFrom = [ (buildTorchExtension buildSet { inherit path rev; }) ];
          # ROCm shells get the torch ROCm architectures and the HIP path.
          env = lib.optionalAttrs rocmSupport {
            PYTORCH_ROCM_ARCH = lib.concatStringsSep ";" buildSet.torch.rocmArchs;
            HIP_PATH = pkgs.rocmPackages.clr;
          };
        };
      };
    filteredBuildSets = applicableBuildSets (readBuildConfig path) buildSets;
  in
  builtins.listToAttrs (map shellForBuildSet filteredBuildSets);
}