in bitsandbytes/nn/modules.py
def forward(self, x: torch.Tensor):
    # the weight is cast automatically (it is a Params4bit), but the bias has to be cast manually
    if self.bias is not None and self.bias.dtype != x.dtype:
        self.bias.data = self.bias.data.to(x.dtype)
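
    # A 4-bit weight must carry its quant_state; restore it here if it went missing.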
    if getattr(self.weight, "quant_state", None) is None:
        if getattr(self, "quant_state", None) is not None:
            # The quant state was lost when the parameter was converted (this happens,
            # for example, under FSDP). Since we registered it on the module, we can
            # recover it here.
            assert self.weight.shape[1] == 1
            if not isinstance(self.weight, Params4bit):
                self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
            self.weight.quant_state = self.quant_state
        else:
            print(
                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
            )
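
    # Derive the compute dtype lazily from the first input this layer sees.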
    if not self.compute_type_is_set:
        self.set_compute_type(x)
        self.compute_type_is_set = True
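
    # Run the 4-bit matmul in compute_dtype, then cast the result back to the input dtype.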
    inp_dtype = x.dtype
    if self.compute_dtype is not None:
        x = x.to(self.compute_dtype)

    bias = None if self.bias is None else self.bias.to(self.compute_dtype)
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
    out = out.to(inp_dtype)

    return out
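
A minimal usage sketch (not part of modules.py) of how this forward path is typically exercised. It assumes bitsandbytes is installed with a CUDA device available; the layer sizes (128, 256) and batch size are arbitrary example values:

import torch
import bitsandbytes as bnb

# Build an FP4-quantized linear layer; moving it to CUDA quantizes the
# weight and attaches the quant_state that forward() relies on.
layer = bnb.nn.LinearFP4(128, 256, bias=True).cuda()

x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
y = layer(x)               # dispatches to bnb.matmul_4bit internally
assert y.dtype == x.dtype  # the output is cast back to the input dtype

Calling the layer before it has been moved to a device would hit the "FP4 quantization state not initialized" branch above, since no quant_state exists yet.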