optimum/habana/transformers/integrations/awq.py
import importlib.metadata
from enum import Enum

import torch.nn as nn
from packaging import version
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_accelerate_available, is_auto_awq_available
from transformers.utils.quantization_config import (
    AwqBackendPackingMethod,
)

from optimum.utils import logging


logger = logging.get_logger(__name__)

class GaudiAWQLinearVersion(str, Enum):
    GEMM = "gemm"
    GEMV = "gemv"
    EXLLAMA = "exllama"
    HPU = "hpu"

    @staticmethod
    def from_str(version: str):
        version = version.lower()
        if version == "gemm":
            return GaudiAWQLinearVersion.GEMM
        elif version == "gemv":
            return GaudiAWQLinearVersion.GEMV
        elif version == "exllama":
            return GaudiAWQLinearVersion.EXLLAMA
        elif version == "hpu":
            return GaudiAWQLinearVersion.HPU
        else:
            raise ValueError(f"Unknown GaudiAWQLinearVersion {version}")

# override post_init in AwqConfig
def gaudi_awq_config_post_init(self):
    """
    Adapted from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/utils/quantization_config.py#L818
    - support HPU.
    """
    if self.backend not in [AwqBackendPackingMethod.AUTOAWQ]:
        raise ValueError(
            f"Only the {AwqBackendPackingMethod.AUTOAWQ} quantization backend is supported - not recognized backend {self.backend}"
        )

    self.version = GaudiAWQLinearVersion.from_str(self.version)
    if self.version not in [
        GaudiAWQLinearVersion.HPU,
        GaudiAWQLinearVersion.GEMM,
    ]:
        raise ValueError(
            f"Only the versions in [GaudiAWQLinearVersion.HPU, GaudiAWQLinearVersion.GEMM] are supported - not recognized version {self.version}"
        )

    if self.do_fuse and self.fuse_max_seq_len is None:
        raise ValueError(
            "You cannot enable fused modules without specifying a `fuse_max_seq_len`, make sure to pass a valid `fuse_max_seq_len` for your usecase"
        )

    if self.do_fuse:
        awq_version_supports_fusing = False
        MIN_AWQ_VERSION = "0.1.7"
        if is_auto_awq_available():
            awq_version_supports_fusing = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
                MIN_AWQ_VERSION
            )

        if not awq_version_supports_fusing:
            raise ValueError(
                f"Your current version of `autoawq` does not support module fusing, please upgrade `autoawq` to at least {MIN_AWQ_VERSION}."
            )

    if self.modules_to_not_convert is not None:
        awq_version_supports_non_conversion = False
        MIN_AWQ_VERSION = "0.1.8"
        if is_auto_awq_available():
            awq_version_supports_non_conversion = version.parse(
                importlib.metadata.version("autoawq")
            ) >= version.parse(MIN_AWQ_VERSION)

        if not awq_version_supports_non_conversion:
            raise ValueError(
                f"Your current version of `autoawq` does not support skipping module quantization, please upgrade `autoawq` to at least {MIN_AWQ_VERSION}."
            )

    if self.do_fuse and self.modules_to_fuse is not None:
        raise ValueError("The current implementation of `autoawq` does not support `do_fuse` and `modules_to_fuse` simultaneously.")


def gaudi_replace_with_awq_linear(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Adapted from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/integrations/awq.py#L90
    - support HPU.

    Returns the model (modified in place) and a boolean indicating whether at least one
    `nn.Linear` module has been replaced.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    assert quantization_config is not None
    backend = quantization_config.backend

    if not is_auto_awq_available():
        raise ValueError(
            "AWQ (either `autoawq` or `llmawq`) is not available. Please install it with `pip install autoawq` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
        )

    if backend == AwqBackendPackingMethod.AUTOAWQ and quantization_config.version == GaudiAWQLinearVersion.HPU:
        from ...AutoAWQ.gemm_hpu import WQLinear_HPU

        target_cls = WQLinear_HPU
    else:
        raise ValueError(f"Unrecognized AWQ version: {quantization_config.version} and backend {backend}")

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                in_features = module.in_features
                out_features = module.out_features

                model._modules[name] = target_cls(
                    w_bit=quantization_config.bits,
                    group_size=quantization_config.group_size,
                    in_features=in_features,
                    out_features=out_features,
                    bias=module.bias is not None,
                    dev=module.weight.device,
                )
                has_been_replaced = True

                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)

        if len(list(module.children())) > 0:
            _, has_been_replaced = gaudi_replace_with_awq_linear(
                module,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                quantization_config=quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
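

# Illustrative sketch (not part of the original module): replacing the linear layers of a toy
# module with HPU AWQ linears. It assumes `autoawq` is installed, an HPU device is available,
# and the quantization config has already been validated by `gaudi_awq_config_post_init`
# (so that `version` is GaudiAWQLinearVersion.HPU).
def _example_replace_linears(quantization_config):
    toy = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
    toy, has_been_replaced = gaudi_replace_with_awq_linear(
        toy,
        quantization_config=quantization_config,
        modules_to_not_convert=["lm_head"],  # hypothetical module name to skip
    )
    return toy, has_been_replaced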


def post_init_awq_gemm_hpu_modules(model):
    """
    Runs post init for GEMM HPU layers, which performs:
    - Weights unpacking, reordering and repacking
    """
    from ...AutoAWQ.gemm_hpu import hpu_post_init

    model = hpu_post_init(model)

    return model


def gaudi_awq_quantizer_process_model_after_weight_loading(self, model, **kwargs):
    if self.quantization_config.version == GaudiAWQLinearVersion.HPU:
        model = post_init_awq_gemm_hpu_modules(model)
    else:
        raise ValueError(f"Unrecognized AWQ version: {self.quantization_config.version}, only HPU is supported")


def gaudi_awq_quantizer_validate_environment(self, device_map, **kwargs):
    if not is_auto_awq_available():
        raise ImportError("Loading an AWQ quantized model requires the auto-awq library (`pip install autoawq`)")

    if not is_accelerate_available():
        raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")

    if device_map is None:
        logger.warning_once(
            "You have loaded an AWQ model without setting a device_map, make sure to place the model "
            "on an HPU device in order to run it."
        )
    else:
        if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
            raise ValueError(
                "You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
                " This is not supported. Please remove the CPU or disk device from the device_map."
            )


def gaudi_awq_quantizer_process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
    from transformers.integrations import get_keys_to_not_convert, replace_quantization_scales

    self.modules_to_not_convert = get_keys_to_not_convert(model)

    if self.quantization_config.modules_to_not_convert is not None:
        self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)

    model, has_been_replaced = gaudi_replace_with_awq_linear(
        model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
    )

    model = replace_quantization_scales(model, model.config.model_type)

    if not has_been_replaced:
        logger.warning(
            "You are loading an AWQ model but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on GitHub if you think this is a bug."
        )