in bitsandbytes/autograd/_functions.py [0:0]
def backward(ctx, grad_output):
    """Compute gradients for the two saved operands ``A`` and ``B``.

    A gradient is produced only for a saved tensor whose ``requires_grad`` is
    set; otherwise its slot in the result is ``None``.

    ``ctx.precision[1]`` selects the code path for ``grad_B`` and
    ``ctx.precision[2]`` the path for ``grad_A``. A value of ``8`` takes the
    quantized path (``F.vectorwise_quant`` -> ``F.igemm`` ->
    ``F.vectorwise_mm_dequant``); any other value uses ``torch.matmul``.

    Args:
        ctx: Context object exposing ``saved_tensors`` (``A``, ``B``),
            ``quant_type`` and ``precision``.
        grad_output: Incoming gradient (presumably w.r.t. the output of a
            forward ``A @ B`` -- the forward pass is not visible here).

    Returns:
        The 5-tuple ``(grad_A, grad_B, None, None, None)``.
    """
    A, B = ctx.saved_tensors
    quant_type = ctx.quant_type
    precision = ctx.precision
    grad_A = grad_B = None

    if B.requires_grad:
        if len(A.shape) == 3:
            dims = [0, 1]
            # bsi -> bis (swap the last two axes)
            permute_dim = [0, 2, 1]
        else:
            dims = [0]
            # bs -> sb... i.e. si -> is (plain 2-D transpose)
            permute_dim = [1, 0]

        if precision[1] != 8:
            with torch.no_grad():
                grad_B = torch.matmul(A.permute(permute_dim), grad_output)
        elif len(B.shape) == 2 and len(A.shape) == 3:
            # 3-D A with 2-D B: collapse the two leading axes of both operands
            # so one 2-D integer GEMM can be used. ``view`` needs contiguous
            # memory, hence the explicit ``contiguous()`` calls.
            grad_output = grad_output.contiguous()
            qgrad_output, S1 = F.vectorwise_quant(
                grad_output.view(-1, grad_output.shape[2]),
                dim=0,
                quant_type=quant_type,
            )
            # Keep the saved ``A`` bound to its original tensor: a copy made
            # while grad mode is disabled reports ``requires_grad=False`` and
            # would make the ``A.requires_grad`` check below skip ``grad_A``.
            A_contig = A.contiguous()
            qA, S2 = F.vectorwise_quant(
                A_contig.view(-1, A_contig.shape[2]),
                dim=0,
                quant_type=quant_type,
            )
            igrad_B = F.igemm(qA.t(), qgrad_output)
            grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
        else:
            qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
            qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
            igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
            grad_B = F.vectorwise_mm_dequant(
                igrad_B,
                S2.permute(permute_dim),
                S1,
                grad_output.dtype,
                quant_type,
            )

    if A.requires_grad:
        # Quantize grad_output along its last axis.
        dims = [2] if len(grad_output.shape) == 3 else [1]

        if len(B.shape) == 3:
            # bio -> boi
            permute_dim = [0, 2, 1]
            dim_B = dims
        else:
            # io -> oi
            permute_dim = [1, 0]
            dim_B = [1]

        if precision[2] != 8:
            with torch.no_grad():
                grad_A = torch.matmul(grad_output, B.permute(permute_dim))
        else:
            qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
            qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
            igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
            grad_A = F.vectorwise_mm_dequant(
                igrad_A,
                S1,
                S3.permute(permute_dim),
                grad_output.dtype,
                quant_type,
            )

    return grad_A, grad_B, None, None, None