in optimum/habana/AutoAWQ/gemm_hpu.py
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
    """Quantized linear layer for HPU (AWQ GEMM).

    Packs `pack_num = 32 // w_bit` quantized values into each int32 of
    `qweight`/`qzeros`; `scales` holds one bf16 scale per group and output channel.
    """
    nn.Module.__init__(self)
    assert w_bit == 4, "Only 4-bit quantization is supported for now."
    self.in_features = in_features
    self.out_features = out_features
    self.w_bit = w_bit
    # group_size == -1 means a single quantization group spanning the full input dimension
    self.group_size = group_size if group_size != -1 else in_features
    self.scale_dtype = torch.float32
    self.training = training
    # quick sanity check (make sure alignment)
    assert self.in_features % self.group_size == 0
    assert out_features % (32 // self.w_bit) == 0
    # number of w_bit-wide values packed into one int32 (8 for 4-bit)
    self.pack_num = 32 // self.w_bit
    self.init_ipex = False
    # packed zero points: one int32 column holds pack_num zero points
    self.register_buffer(
        "qzeros",
        torch.zeros(
            (in_features // self.group_size, out_features // self.pack_num),
            dtype=torch.int32,
            device=dev,
        ),
    )
    # per-group dequantization scales, one per output channel
    self.register_buffer(
        "scales",
        torch.zeros(
            (in_features // self.group_size, out_features),
            dtype=torch.bfloat16,
            device=dev,
        ),
    )
    if bias:
        self.register_buffer(
            "bias",
            torch.zeros((out_features,), dtype=torch.bfloat16, device=dev),
        )
    else:
        self.bias = None
    # packed quantized weights: pack_num values per int32 along the output dimension
    self.register_buffer(
        "qweight",
        torch.zeros((in_features, out_features // self.pack_num), dtype=torch.int32, device=dev),
    )
    self._preprocess = False
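
# A minimal standalone sketch (not part of the source file) of the shape
# arithmetic the buffers above follow. The sizes below are illustrative; only
# w_bit == 4 passes the assert in __init__.
import torch

w_bit, group_size = 4, 128
in_features, out_features = 4096, 11008

pack_num = 32 // w_bit          # 8 four-bit values packed into each int32
n_groups = in_features // group_size

# Same shapes and dtypes the constructor registers:
qweight = torch.zeros((in_features, out_features // pack_num), dtype=torch.int32)
qzeros = torch.zeros((n_groups, out_features // pack_num), dtype=torch.int32)
scales = torch.zeros((n_groups, out_features), dtype=torch.bfloat16)

print(qweight.shape)  # torch.Size([4096, 1376])
print(qzeros.shape)   # torch.Size([32, 1376])
print(scales.shape)   # torch.Size([32, 11008])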