overlay.nix (134 lines of code) (raw):
final: prev:
rec {
# Use MKL for BLAS/LAPACK on x86_64; every other platform keeps the
# provider that the previous package set already configured.
blas = if final.stdenv.isx86_64 then prev.blas.override { blasProvider = prev.mkl; } else prev.blas;
# The fallback must be `prev.lapack`: falling back to `prev.blas` would
# substitute the BLAS package for LAPACK on every non-x86_64 platform.
lapack =
  if final.stdenv.isx86_64 then prev.lapack.override { lapackProvider = prev.mkl; } else prev.lapack;
# Packages defined in this repository, instantiated against the final package set.
build2cmake = final.callPackage ./pkgs/build2cmake { };
# `fetchKernel` is used by the Python package extension in this overlay as the
# `src` fetcher for the kernels built with `buildKernel`.
fetchKernel = final.callPackage ./pkgs/fetch-kernel { };
# Extend the static CUDA MAGMA build with the library output of gfortran's
# compiler package (`gfortran.cc`).
# NOTE(review): presumably needed so MAGMA can link against the Fortran
# runtime — confirm.
magma-cuda-static = prev.magma-cuda-static.overrideAttrs (
  _finalAttrs: old:
  let
    gfortranLibs = prev.lib.getLib prev.gfortran.cc;
  in
  {
    buildInputs = old.buildInputs ++ [ gfortranLibs ];
  }
);
# MAGMA built from the local ./pkgs/magma expression with ROCm support enabled
# and CUDA support disabled; only the `magma` attribute of the result is exposed.
magma-hip =
  let
    hipMagmaPackages = prev.callPackage ./pkgs/magma {
      rocmSupport = true;
      cudaSupport = false;
    };
  in
  hipMagmaPackages.magma;
pythonPackagesExtensions = prev.pythonPackagesExtensions ++ [
(
python-self: python-super: with python-self; {
# `paged-attention` kernel: built with `buildKernel` from the source that
# `fetchKernel` retrieves for repo_id `kernels-community/paged-attention`
# at the given version.
paged-attention =
  let
    pname = "paged-attention";
    version = "0.0.3";
  in
  buildKernel {
    inherit pname version;
    src = fetchKernel {
      inherit version;
      repo_id = "kernels-community/${pname}";
      hash = "sha256-Nalnx3kjQcKOa5YaBSKRImnVUHq2lko1TMfFJ7ocTb4=";
    };
  };
# Python modules packaged under ./pkgs/python-modules. `callPackage` is not
# bound lexically here, so it resolves through `with python-self;`, i.e. each
# module is instantiated against this Python package set.
attention-kernels = callPackage ./pkgs/python-modules/attention-kernels { };
awq-inference-engine = callPackage ./pkgs/python-modules/awq-inference-engine { };
# Builder used by the kernel packages in this set (paged-attention, moe,
# punica-sgmv, quantization, quantization-eetq, rotary).
buildKernel = callPackage ./pkgs/python-modules/build-kernel { };
causal-conv1d = callPackage ./pkgs/python-modules/causal-conv1d { };
compressed-tensors = callPackage ./pkgs/python-modules/compressed-tensors { };
exllamav2 = callPackage ./pkgs/python-modules/exllamav2 { };
flash-attn = callPackage ./pkgs/python-modules/flash-attn { };
flash-attn-layer-norm = callPackage ./pkgs/python-modules/flash-attn-layer-norm { };
flash-attn-rotary = callPackage ./pkgs/python-modules/flash-attn-rotary { };
flash-attn-v1 = callPackage ./pkgs/python-modules/flash-attn-v1 { };
flashinfer = callPackage ./pkgs/python-modules/flashinfer { };
hf-transfer = callPackage ./pkgs/python-modules/hf-transfer { };
hf-xet = callPackage ./pkgs/python-modules/hf-xet { };
kernels = callPackage ./pkgs/python-modules/kernels { };
marlin-kernels = callPackage ./pkgs/python-modules/marlin-kernels { };
# `moe` kernel: built with `buildKernel` from the source that `fetchKernel`
# retrieves for repo_id `kernels-community/moe` at the given version.
moe =
  let
    pname = "moe";
    version = "0.3.0";
  in
  buildKernel {
    inherit pname version;
    src = fetchKernel {
      inherit version;
      repo_id = "kernels-community/${pname}";
      hash = "sha256-CR0fgiooxt+pBE/VQETjNn7i0jSCD0g2NDC7KdNNIrc=";
    };
  };
# Locally packaged module, instantiated with this Python set's `callPackage`.
moe-kernels = callPackage ./pkgs/python-modules/moe-kernels { };
# Disabled override that would build opentelemetry-proto against `protobuf3_24`.
# NOTE(review): `super` is not bound in the visible scope (the arguments are
# `final`/`prev` and `python-self`/`python-super`), so re-enabling this line
# would need e.g. `prev.protobuf3_24` instead.
#opentelemetry-proto = python-super.opentelemetry-proto.override { protobuf = super.protobuf3_24; };
# Override of opentelemetry-instrumentation-grpc that clears `patches` and
# regenerates the test protobuf/gRPC Python files before the check phase.
opentelemetry-instrumentation-grpc =
  let
    # The protobuf files shipped in tests/protobuf are stale and make the
    # tests fail, so they are overwritten with freshly generated ones.
    regenerateTestProtos = ''
      python -m grpc_tools.protoc -Itests/protobuf --python_out=tests/protobuf \
      --grpc_python_out=tests/protobuf tests/protobuf/test_server.proto # --mypy_out=text_generation_server/pb
    '';
  in
  python-super.opentelemetry-instrumentation-grpc.overrideAttrs (
    _finalAttrs: old: {
      # `grpc_tools.protoc` (from grpcio-tools) is invoked in preCheck.
      nativeBuildInputs = old.nativeBuildInputs ++ [ python-super.grpcio-tools ];
      patches = [ ];
      preCheck = regenerateTestProtos;
    }
  );
# Locally packaged under ./pkgs/python-modules and instantiated with this
# Python package set's `callPackage`.
mamba-ssm = callPackage ./pkgs/python-modules/mamba-ssm { };
# `punica-sgmv` kernel: built with `buildKernel`, with `final.cutlass_3_6`
# passed as `cutlass`. Unlike the other kernels in this set, the source is
# pinned by `rev` rather than by `version`.
punica-sgmv =
  let
    pname = "punica-sgmv";
    version = "0.0.1";
  in
  buildKernel {
    inherit pname version;
    cutlass = final.cutlass_3_6;
    src = fetchKernel {
      repo_id = "kernels-community/${pname}";
      # Pinned to an exact revision; `version` is not passed to the fetcher.
      rev = "5a84343633e93e2866f4e907dfc26b1ee07467ae";
      hash = "sha256-z2em4jEZSgDfPX6s4jykVpuJOI1LRbI69Xq1T5lTM7s=";
    };
  };
# `quantization` kernel: built with `buildKernel` from the source that
# `fetchKernel` retrieves for repo_id `kernels-community/quantization` at the
# given version, with `final.cutlass_3_6` passed as `cutlass`.
quantization =
  let
    pname = "quantization";
    version = "0.0.4";
  in
  buildKernel {
    inherit pname version;
    cutlass = final.cutlass_3_6;
    src = fetchKernel {
      inherit version;
      repo_id = "kernels-community/${pname}";
      hash = "sha256-qAMKM+2pKbYkJ9bHWlVijKcknrBjeFHLTXU2LCKA2dw=";
    };
  };
# `quantization-eetq` kernel: built with `buildKernel` from the source that
# `fetchKernel` retrieves for repo_id `kernels-community/quantization-eetq`
# at the given version, with `final.cutlass_2_10` passed as `cutlass`.
quantization-eetq =
  let
    pname = "quantization-eetq";
    version = "0.0.2";
  in
  buildKernel {
    inherit pname version;
    cutlass = final.cutlass_2_10;
    src = fetchKernel {
      inherit version;
      repo_id = "kernels-community/${pname}";
      hash = "sha256-TiJAGEpZAR/UxovzsVhc5mto3FD4LA8urE04XA2D4KQ=";
    };
  };
# ROCm package set exposed inside this Python package set: the 6.3 set from
# the final package set, the same one that `torch_2_7` is built with.
rocmPackages = final.rocmPackages_6_3;
# `rotary` kernel: built with `buildKernel` from the source that `fetchKernel`
# retrieves for repo_id `kernels-community/rotary` at the given version.
rotary =
  let
    pname = "rotary";
    version = "0.0.2";
  in
  buildKernel {
    inherit pname version;
    src = fetchKernel {
      inherit version;
      repo_id = "kernels-community/${pname}";
      hash = "sha256-D5/ErUNCQbNrbLGBNiucuWocyv+343W7tius6NcM9iQ=";
    };
  };
# `torch` is an alias for the 2.7 build, taken from `python-self` so that
# later overrides of `torch_2_7` are picked up.
torch = python-self.torch_2_7;
# Each torch version is built against its own ROCm package set
# (2.6 -> rocmPackages_6_2, 2.7 -> rocmPackages_6_3).
torch_2_6 = callPackage ./pkgs/python-modules/torch_2_6 { rocmPackages = final.rocmPackages_6_2; };
torch_2_7 = callPackage ./pkgs/python-modules/torch_2_7 { rocmPackages = final.rocmPackages_6_3; };
}
)
];
}
# Merge in the attribute set returned by ./pkgs/cutlass, evaluated with
# `pkgs = final`.
# NOTE(review): presumably this is where the `cutlass_3_6` / `cutlass_2_10`
# attributes referenced through `final` in this overlay come from — confirm.
// (import ./pkgs/cutlass { pkgs = final; })
# Merge in one `rocmPackages_<major>_<minor>` attribute per listed ROCm
# release, each built by ./pkgs/rocm-packages from that release's metadata
# JSON file.
// (
  let
    inherit (prev) lib;

    # Full release versions; each needs a matching
    # ./pkgs/rocm-packages/rocm-<release>-metadata.json file.
    supportedReleases = [
      "6.2.4"
      "6.3.4"
    ];

    mkRocmPackages = final.callPackage ./pkgs/rocm-packages { };

    # Replaces every "." with "_" so the version can be used in an attribute name.
    dotsToUnderscores = lib.strings.replaceStrings [ "." ] [ "_" ];

    # e.g. "6.2.4" -> "rocmPackages_6_2"
    attrNameFor = release: "rocmPackages_${dotsToUnderscores (lib.versions.majorMinor release)}";

    packageSetFor =
      release:
      mkRocmPackages {
        packageMetadata = lib.importJSON ./pkgs/rocm-packages/rocm-${release}-metadata.json;
      };

    toEntry = release: {
      name = attrNameFor release;
      value = packageSetFor release;
    };
  in
  builtins.listToAttrs (map toEntry supportedReleases)
)