pkgs/python-modules/flash-attn-layer-norm/default.nix
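# Builds the standalone csrc/layer_norm CUDA extension from the FlashAttention
# repository; the result installs the `dropout_layer_norm` Python module
# (fused dropout + residual add + LayerNorm/RMSNorm kernels).
#
# A minimal sketch of how this expression might be consumed, assuming it is
# wired into a python-packages overlay (the attribute name and path below are
# illustrative, not defined by this file):
#
#   flash-attn-layer-norm =
#     python3Packages.callPackage ./pkgs/python-modules/flash-attn-layer-norm { };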
{
  lib,
  stdenv,
  fetchFromGitHub,
  buildPythonPackage,
  autoAddDriverRunpath,
  cmake,
  git,
  ninja,
  packaging,
  psutil,
  which,
  cudaPackages,
  torch,
}:
buildPythonPackage rec {
  pname = "flash-attn-layer-norm";
  version = "2.6.3";

  src = fetchFromGitHub {
    owner = "Dao-AILab";
    repo = "flash-attention";
    rev = "v${version}";
    fetchSubmodules = true;
    hash = "sha256-ht234geMnOH0xKjhBOCXrzwYZuBFPvJMCZ9P8Vlpxcs=";
  };
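
  # Build only the layer_norm sub-project of the monorepo; its setup.py lives
  # under csrc/layer_norm.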
  sourceRoot = "${src.name}/csrc/layer_norm";
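
  # Use the CUDA-aware stdenv so the host compiler is one that nvcc supports.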
  stdenv = cudaPackages.backendStdenv;
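
  # psutil resolves from the Python package set (upstream's setup.py expects it
  # at build time); the cudaPackages entries are the CUDA libraries the
  # extension is compiled and linked against.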
  buildInputs =
    [ psutil ]
    ++ (with cudaPackages; [
      cuda_cccl
      cuda_cudart
      libcublas
      libcurand
      libcusolver
      libcusparse
    ]);
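
  # Build-time tooling only; ninja drives the parallel compilation of the CUDA
  # sources.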
  nativeBuildInputs = [
    autoAddDriverRunpath
    cmake
    git
    ninja
    packaging
    which
  ];
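
  # CUDA_HOME points the setuptools/torch extension build at the nvcc toolchain.
  # FLASH_ATTENTION_FORCE_BUILD forces a from-source build instead of fetching a
  # prebuilt wheel; it is read by the main flash-attn setup.py and is harmless
  # here if this sub-package's setup.py ignores it.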
  env = {
    CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}";
    FLASH_ATTENTION_FORCE_BUILD = "TRUE";
  };
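
  # torch is needed both at build time (torch.utils.cpp_extension) and at runtime.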
  propagatedBuildInputs = [ torch ];

  # cmake/ninja are used for parallel builds, but we don't want the
  # cmake configure hook to kick in.
  dontUseCmakeConfigure = true;

  # We don't have any tests in this package (yet).
  doCheck = false;
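
  # Cap the parallel compile jobs (honored via MAX_JOBS by torch's cpp_extension
  # builder) at the cores Nix allots to this build.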
  preBuild = ''
    export MAX_JOBS=$NIX_BUILD_CORES
  '';
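
  # The extension installs as the dropout_layer_norm module; importing it is the
  # smoke test.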
  pythonImportsCheck = [ "dropout_layer_norm" ];
  meta = with lib; {
    description = "Fused dropout, residual-add and layer norm CUDA kernels from FlashAttention";
    homepage = "https://github.com/Dao-AILab/flash-attention";
    license = licenses.bsd3;
    platforms = platforms.linux;
  };
}