flake.nix (107 lines of code) (raw):

{ description = "Hugging Face Nix overlay"; inputs = { flake-utils.url = "github:numtide/flake-utils"; nixpkgs.url = "github:danieldk/nixpkgs/cudatoolkit-12.9-kernel-builder"; flake-compat.url = "github:edolstra/flake-compat"; }; outputs = { self, flake-compat, flake-utils, nixpkgs, }: let isCudaSystem = system: system == "x86_64-linux" || system == "aarch64-linux"; cudaConfig = { allowUnfree = true; cudaSupport = true; cudaCapabilities = [ "7.5" "8.0" "8.6" "8.9" "9.0" "9.0a" ]; }; rocmConfig = { allowUnfree = true; rocmSupport = true; }; overlay = import ./overlay.nix; in flake-utils.lib.eachSystem (with flake-utils.lib.system; [ aarch64-darwin x86_64-linux ]) ( system: let pkgsCuda = import nixpkgs { inherit system; config = cudaConfig; overlays = [ overlay ]; }; pkgsRocm = import nixpkgs { inherit system; config = rocmConfig; overlays = [ overlay ]; }; pkgsGeneric = import nixpkgs { inherit system; overlays = [ overlay ]; }; pkgs = if isCudaSystem system then pkgsCuda else pkgsGeneric; inherit (pkgs) lib; in rec { formatter = pkgs.nixfmt-tree; packages = rec { all = pkgs.symlinkJoin { name = "all"; paths = builtins.filter (lib.meta.availableOn { inherit system; }) (lib.attrValues python3Packages); }; lib = pkgs.lib; python3Packages = with pkgs.python3.pkgs; { inherit awq-inference-engine causal-conv1d compressed-tensors exllamav2 flash-attn flash-attn-layer-norm flash-attn-rotary flash-attn-v1 flashinfer hf-transfer hf-xet kernels mamba-ssm moe opentelemetry-instrumentation-grpc outlines paged-attention punica-sgmv quantization quantization-eetq rotary torch ; }; rocm = { python3Packages = with pkgsRocm.python3.pkgs; { inherit torch; }; }; }; } ) // { # Cheating a bit to conform to the schema. lib.config = system: if isCudaSystem system then cudaConfig else { }; overlays.default = overlay; }; }