in bitsandbytes/nn/triton_based_modules.py
@staticmethod
def backward(ctx, G_3D):
    # tensors stashed by forward via ctx.save_for_backward are read back
    # through ctx.saved_tensors
    X, W = ctx.saved_tensors

    # flatten the leading (batch, sequence) dims so G is 2D: [N, D_out]
    G = G_3D.reshape(-1, G_3D.size(-1))

    grad_X = grad_W = grad_bias = None

    if ctx.needs_input_grad[0]:
        # rowwise quantize G; columnwise quantize W with a fused transpose.
        # we call .t() on the weight later because only A @ B^T is supported.
        G_int8, state_G = quantize_rowwise(G)
        W_int8, state_W = quantize_columnwise_and_transpose(W)
        grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
            *G_3D.size()[:-1],
            -1,
        )
    if ctx.needs_input_grad[1]:
        # the weight gradient is computed in full precision: grad_W = G^T @ X
        grad_W = torch.matmul(G.t(), X.to(G.dtype))
    if ctx.needs_input_grad[2]:
        grad_bias = G.sum(dim=0)

    return grad_X, grad_W, grad_bias
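
For reference, the sketch below (not part of the library) reproduces the same gradient math in plain full-precision PyTorch and simulates the int8 round-trip, so the approximation introduced in grad_X can be inspected on CPU without the Triton kernels. The per-row and per-column absmax scaling is an assumption about the semantics of quantize_rowwise and quantize_columnwise_and_transpose, inferred from their names and from the rowwise dequantize in the matmul call; grad_W and grad_bias are exact in both versions.

import torch


def rowwise_int8_roundtrip(t: torch.Tensor) -> torch.Tensor:
    # assumed semantics: quantize with one absmax scale per row, then dequantize
    scale = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    return (t / scale).round().clamp(-127, 127) * scale


def colwise_int8_roundtrip(t: torch.Tensor) -> torch.Tensor:
    # assumed semantics: quantize with one absmax scale per column, then dequantize
    scale = t.abs().amax(dim=0, keepdim=True).clamp(min=1e-8) / 127.0
    return (t / scale).round().clamp(-127, 127) * scale


B, L, D_in, D_out = 2, 4, 8, 16
X = torch.randn(B * L, D_in)      # flattened input, as saved in forward
W = torch.randn(D_out, D_in)      # weight stored as [D_out, D_in]
G_3D = torch.randn(B, L, D_out)   # incoming gradient w.r.t. the layer output
G = G_3D.reshape(-1, G_3D.size(-1))

# exact reference gradients for y = X @ W^T + bias
grad_X_ref = (G @ W).view(B, L, D_in)
grad_W_ref = G.t() @ X            # matches the full-precision grad_W above
grad_bias_ref = G.sum(dim=0)

# grad_X as the int8 path computes it: rowwise-quantized G times
# columnwise-quantized W, dequantized back to floating point
grad_X_int8 = (rowwise_int8_roundtrip(G) @ colwise_int8_roundtrip(W)).view(B, L, D_in)
print((grad_X_int8 - grad_X_ref).abs().max())  # small quantization error

Only grad_X goes through the int8 path; the backward above deliberately keeps grad_W and grad_bias in full precision, so the round-trip error shows up in the input gradient alone.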