def _get_weights_configs()

in llm_perf/benchmark_runners/cuda/update_llm_perf_cuda_pytorch.py [0:0]


    def _get_weights_configs(self, subset) -> Dict[str, Dict[str, Any]]:
        """Return the weight/quantization configurations to benchmark for ``subset``.

        Args:
            subset: One of ``"unquantized"``, ``"bnb"``, ``"gptq"``, ``"awq"`` or
                ``"torchao"``.

        Returns:
            A mapping from a configuration name (e.g. ``"4bit-gptq-exllama-v1"``)
            to a dict with three keys:

            * ``"torch_dtype"`` -- dtype name string used to load the weights,
            * ``"quant_scheme"`` -- quantization scheme name, or ``None`` for
              unquantized weights,
            * ``"quant_config"`` -- scheme-specific keyword arguments (empty for
              unquantized weights).

            A fresh set of dicts is built on every call, so callers may mutate
            the result without affecting later calls.

        Raises:
            ValueError: If ``subset`` is not one of the names listed above.
        """
        if subset == "unquantized":
            return {
                "float32": {
                    "torch_dtype": "float32",
                    "quant_scheme": None,
                    "quant_config": {},
                },
                "float16": {
                    "torch_dtype": "float16",
                    "quant_scheme": None,
                    "quant_config": {},
                },
                "bfloat16": {
                    "torch_dtype": "bfloat16",
                    "quant_scheme": None,
                    "quant_config": {},
                },
            }
        elif subset == "bnb":
            return {
                "4bit-bnb": {
                    "torch_dtype": "float16",
                    "quant_scheme": "bnb",
                    "quant_config": {"load_in_4bit": True},
                },
                "8bit-bnb": {
                    "torch_dtype": "float16",
                    "quant_scheme": "bnb",
                    "quant_config": {"load_in_8bit": True},
                },
            }
        elif subset == "gptq":
            # The key must be exactly "use_exllama". A trailing space would not
            # match the GPTQ config parameter name, and the flag would be lost.
            # NOTE(review): the exllama version is given as a top-level "version"
            # here, whereas the AWQ entries nest theirs under "exllama_config".
            # Confirm the downstream consumer reads it from this key.
            return {
                "4bit-gptq-exllama-v1": {
                    "torch_dtype": "float16",
                    "quant_scheme": "gptq",
                    "quant_config": {
                        "bits": 4,
                        "use_exllama": True,
                        "version": 1,
                        "model_seqlen": 256,
                    },
                },
                "4bit-gptq-exllama-v2": {
                    "torch_dtype": "float16",
                    "quant_scheme": "gptq",
                    "quant_config": {
                        "bits": 4,
                        "use_exllama": True,
                        "version": 2,
                        "model_seqlen": 256,
                    },
                },
            }
        elif subset == "awq":
            return {
                "4bit-awq-gemm": {
                    "torch_dtype": "float16",
                    "quant_scheme": "awq",
                    "quant_config": {"bits": 4, "version": "gemm"},
                },
                "4bit-awq-gemv": {
                    "torch_dtype": "float16",
                    "quant_scheme": "awq",
                    "quant_config": {"bits": 4, "version": "gemv"},
                },
                "4bit-awq-exllama-v1": {
                    "torch_dtype": "float16",
                    "quant_scheme": "awq",
                    "quant_config": {
                        "bits": 4,
                        "version": "exllama",
                        "exllama_config": {
                            "version": 1,
                            "max_input_len": 64,
                            "max_batch_size": 1,
                        },
                    },
                },
                "4bit-awq-exllama-v2": {
                    "torch_dtype": "float16",
                    "quant_scheme": "awq",
                    "quant_config": {
                        "bits": 4,
                        "version": "exllama",
                        "exllama_config": {
                            "version": 2,
                            "max_input_len": 64,
                            "max_batch_size": 1,
                        },
                    },
                },
            }
        elif subset == "torchao":
            return {
                "torchao-int4wo-128": {
                    "torch_dtype": "bfloat16",
                    "quant_scheme": "torchao",
                    "quant_config": {
                        "quant_type": "int4_weight_only",
                        "group_size": 128,
                    },
                },
            }
        else:
            raise ValueError(f"Unknown subset: {subset}")