lib/build-variants.nix (54 lines of code) (raw):
{ lib, torchVersions }:
let
inherit (import ./torch-version-utils.nix { inherit lib; })
flattenSystems
isCuda
isMetal
isRocm
;
in
rec {
computeFramework =
buildConfig:
if buildConfig ? cudaVersion then
"cuda"
else if buildConfig ? metal then
"metal"
else if buildConfig ? "rocmVersion" then
"rocm"
else
throw "Could not find compute framework: no CUDA or ROCm version specified and Metal is not enabled";
# Upstream build variants.
buildVariants =
let
inherit (import ./version-utils.nix { inherit lib; }) abiString flattenVersion;
computeString =
version:
if isCuda version then
"cu${flattenVersion (lib.versions.majorMinor version.cudaVersion)}"
else if isRocm version then
"rocm${flattenVersion (lib.versions.majorMinor version.rocmVersion)}"
else if isMetal version then
"metal"
else
throw "No compute framework set in Torch version";
buildName =
version:
if version.system == "aarch64-darwin" then
"torch${flattenVersion version.torchVersion}-${computeString version}-${version.system}"
else
"torch${flattenVersion version.torchVersion}-${abiString version.cxx11Abi}-${computeString version}-${version.system}";
upstreamVersions = lib.filter (version: version.upstreamVariant or false);
in
lib.foldl' (
acc: version:
let
path = [
version.system
(computeFramework version)
];
pathVersions = lib.attrByPath path [ ] acc ++ [ (buildName version) ];
in
lib.recursiveUpdate acc (lib.setAttrByPath path pathVersions)
) { } (flattenSystems (upstreamVersions torchVersions));
}