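# A minimal sketch of how this expression might be instantiated, assuming a
# Python package set whose callPackage can supply cudaPackages and a
# CUDA-enabled torch (the path and attribute names here are illustrative):
#
#   marlin-kernels = python3Packages.callPackage ./marlin-kernels { };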
{
  lib,
  stdenv,
  fetchFromGitHub,
  buildPythonPackage,
  autoAddDriverRunpath,
  cmake,
  ninja,
  which,
  packaging,
  setuptools,
  wheel,
  cudaPackages,
  torch,
}:
let
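  # CUTLASS is pinned and fetched ahead of time because the upstream build
  # pulls it in via CMake's FetchContent, which cannot download anything
  # inside the Nix build sandbox. The source is handed to CMake through
  # FETCHCONTENT_SOURCE_DIR_CUTLASS below.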
  cutlass = fetchFromGitHub {
    owner = "NVIDIA";
    repo = "cutlass";
    rev = "refs/tags/v3.6.0";
    hash = "sha256-FbMVqR4eZyum5w4Dj5qJgBPOS66sTem/qKZjYIK/7sg=";
  };
in
buildPythonPackage rec {
  pname = "marlin-kernels";
  version = "0.3.7";
  src = fetchFromGitHub {
    owner = "danieldk";
    repo = pname;
    rev = "v${version}";
    hash = "sha256-xh4EnjFSQ3VrGGOsZOMbmIwfFmY6N/KghjmXMTn4tfc=";
  };
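  # Teach upstream's setup.py to honour the cmakeFlags we set below (used to
  # point FetchContent at the pre-fetched CUTLASS source).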
  patches = [
    ./setup.py-nix-support-respect-cmakeFlags.patch
  ];
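  # Build with the CUDA-aware stdenv so the host compiler is one that nvcc
  # supports.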
  stdenv = cudaPackages.backendStdenv;
  nativeBuildInputs = with cudaPackages; [
    autoAddDriverRunpath
    cmake
    cuda_nvcc
    ninja
    which
  ];
  build-system = [
    packaging
    setuptools
    wheel
  ];
  buildInputs = with cudaPackages; [
    cuda_cccl
    cuda_cudart
    cuda_nvtx
    libcublas
    libcusolver
    libcusparse
  ];
  env = {
    CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
    # Build time is dominated by a few kernels, so have nvcc compile the code
    # for every target capability in parallel.
    NVCC_THREADS = builtins.length torch.cudaCapabilities;
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" torch.cudaCapabilities;
  };
  propagatedBuildInputs = [ torch ];
  # cmake/ninja are used for parallel builds, but we don't want the
  # cmake configure hook to kick in.
  dontUseCmakeConfigure = true;
  cmakeFlags = [ (lib.cmakeFeature "FETCHCONTENT_SOURCE_DIR_CUTLASS" "${lib.getDev cutlass}") ];
  # We don't have any tests in this package (yet).
  doCheck = false;
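  # But at least check that the compiled extension can be imported.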
  pythonImportsCheck = [ "marlin_kernels" ];
  meta = with lib; {
    description = "Marlin quantization kernels";
    license = licenses.asl20;
    platforms = platforms.linux;
  };
}