optimum/habana/AutoAWQ/gemm_hpu.py:

import torch
import torch.nn as nn
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.utils.packing_utils import reverse_awq_order, unpack_awq

try:
    import habana_frameworks.torch.hpu  # noqa: F401

    convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
    hpu_import_exception = e

    def error_raiser_hpu(*args, **kwargs):
        raise ValueError(
            f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
        )

    convert_from_uint4 = error_raiser_hpu


def unpack_weight_and_zeros(qweight, qzeros, bits):
    # Unpack the qweight and qzeros tensors
    iweight, izeros = unpack_awq(qweight, qzeros, bits)

    # Reverse the order of the iweight and izeros tensors
    iweight, izeros = reverse_awq_order(iweight, izeros, bits)

    # Overflow checks: keep only the low `bits` bits of each value
    iweight = torch.bitwise_and(iweight, (2**bits) - 1)
    izeros = torch.bitwise_and(izeros, (2**bits) - 1)

    return iweight, izeros


def pack_tensor(input, bits=4):
    # Pack 32 // bits values of `input` into each int32 of the output tensor
    normal = input.to(torch.int32)
    q = torch.zeros(
        (normal.shape[0], normal.shape[1] // 32 * bits),
        dtype=torch.int32,
        device=input.device,
    )
    i = 0
    col = 0
    while col < q.shape[1]:
        for j in range(i, i + (32 // bits)):
            q[:, col] |= normal[:, j] << (bits * (j - i))
        i += 32 // bits
        col += 1
    q = q.to(torch.int32)
    return q


class WQLinear_HPU(WQLinear_GEMM):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
        nn.Module.__init__(self)

        assert w_bit == 4, "Only 4-bit quantization is supported for now."

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
        self.scale_dtype = torch.float32
        self.training = training

        # Quick sanity check (make sure dimensions are aligned)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
        self.pack_num = 32 // self.w_bit
        self.init_ipex = False

        self.register_buffer(
            "qzeros",
            torch.zeros(
                (in_features // self.group_size, out_features // self.pack_num),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (in_features // self.group_size, out_features),
                dtype=torch.bfloat16,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros((out_features), dtype=torch.bfloat16, device=dev),
            )
        else:
            self.bias = None
        self.register_buffer(
            "qweight",
            torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev),
        )
        self._preprocess = False

    def _preprocessing(self):
        # Unpack the AWQ GEMM layout and repack it into the layout expected by
        # the HPU convert_from_uint4 kernel
        device = self.qweight.device
        weight, zeros = unpack_weight_and_zeros(self.qweight.cpu(), self.qzeros.cpu(), self.w_bit)
        self.qweight = pack_tensor(weight).to(device)
        self.qzeros = pack_tensor(zeros).to(device)
        self._preprocess = True

    def post_init(self):
        self._preprocessing()

    @classmethod
    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only:  # just prepare for loading the state dict
            return awq_linear

        raise NotImplementedError("Only inference is supported for HPU kernels")

    def forward(self, x):
        assert self._preprocess is True, (
            "module.post_init() must be called before module.forward(). Use hpu_post_init() on the whole model."
        )
        out_shape = x.shape[:-1] + (self.out_features,)
        x = x.reshape(-1, x.shape[-1])
        # Dequantize the packed weights on the fly, then run a regular matmul
        weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
        outputs = torch.matmul(x, weights)
        outputs = outputs + self.bias if self.bias is not None else outputs
        outputs = outputs.reshape(out_shape)

        return outputs

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
            self.in_features,
            self.out_features,
            self.bias is not None,
            self.w_bit,
            self.group_size,
        )


def hpu_post_init(model):
    # Repack the buffers of every WQLinear_HPU layer in the model once,
    # before the first forward pass
    for _, submodule in model.named_modules():
        if isinstance(submodule, WQLinear_HPU):
            submodule.post_init()
    return model
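

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module; assumes a
    # machine with an HPU device and habana_frameworks installed). It builds a
    # standalone WQLinear_HPU layer with its default zero-initialized quantized
    # buffers, repacks them with hpu_post_init(), and runs one forward pass.
    # The layer sizes below are arbitrary example values.
    layer = WQLinear_HPU(
        w_bit=4,
        group_size=128,
        in_features=4096,
        out_features=4096,
        bias=False,
        dev="hpu",
    )
    # hpu_post_init() also works on a single layer, since named_modules()
    # yields the module itself
    hpu_post_init(layer)
    x = torch.randn(1, 4096, dtype=torch.bfloat16, device="hpu")
    print(layer(x).shape)  # expected: torch.Size([1, 4096])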