bitsandbytes/research/autograd/_functions.py [41:71]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        output = torch.matmul(fp8A, fp8B)

        # output is in the input dtype (typically half); precision was already reduced by the 8-bit round trip above

        # 3. Save state
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: A is saved in full precision and re-quantized during backward.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
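
The forward pass above simulates FP8 precision in software: each operand is quantized to 8 bits with the forward code map (fw_code) and immediately dequantized, so the matmul runs in the original dtype but on values carrying only ~8 bits of information per element. Below is a minimal sketch of that round trip, using only the F.quantize / F.dequantize calls visible in the excerpt; the simulated_fp8_matmul name is a hypothetical helper, and in the real file A is quantized blockwise while B uses the global quantize shown here.

import torch
import bitsandbytes.functional as F

def simulated_fp8_matmul(A, B, fw_code=None):
    # Round-trip both operands through an 8-bit code map. With code=None,
    # F.quantize falls back to its dynamic 8-bit map; building an FP8-style
    # map (e.g. via F.create_fp8_map) is an assumption, not shown above.
    cA, stateA = F.quantize(A.float(), code=fw_code)
    fp8A = F.dequantize(cA, stateA).to(A.dtype)

    cB, stateB = F.quantize(B.float(), code=fw_code)
    fp8B = F.dequantize(cB, stateB).to(B.dtype)

    # The matmul itself runs in the original (typically half) dtype; only
    # its inputs have been degraded to 8-bit precision.
    return torch.matmul(fp8A, fp8B)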



bitsandbytes/research/autograd/_functions.py [125:155]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        cB, state = F.quantize(B.float(), code=fw_code)
        fp8B = F.dequantize(cB, state).to(B.dtype)

        output = torch.matmul(fp8A, fp8B)

        # output is in the input dtype (typically half); precision was already reduced by the 8-bit round trip above

        # 3. Save state
        ctx.fw_code = fw_code
        ctx.bw_code = bw_code
        ctx.bsz = bsz
        ctx.bsz2 = bsz2
        ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype

        if any(ctx.needs_input_grad[:2]):
            # NOTE: A is saved in full precision and re-quantized during backward.
            ctx.tensors = (A, fp8B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None

        req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
        A, B = ctx.tensors

        grad_A, grad_B = None, None
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
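
Both excerpts cut off right after grad_A, grad_B = None, None, but the state saved in forward (bw_code, bsz2, the full-precision A, and the already-dequantized fp8B) indicates how the backward pass continues: grad_output is round-tripped through the backward code map, then grad_A = grad @ B^T and grad_B = A^T @ grad, with A re-quantized on the fly (the "re-quant" note above). The following is a hedged sketch of that continuation for the 2D case, not the file's exact code past lines 71/155; fp8_backward_sketch is a hypothetical name.

import torch
import bitsandbytes.functional as F

def fp8_backward_sketch(ctx, grad_output):
    # Sketch only: the real continuation is truncated in the excerpt above.
    A, B = ctx.tensors
    req_gradA, req_gradB = ctx.needs_input_grad[:2]

    # Round-trip the incoming gradient through the backward code map.
    cg, state = F.quantize(grad_output.float(), code=ctx.bw_code)
    fp8_grad = F.dequantize(cg, state).to(grad_output.dtype)

    grad_A = grad_B = None
    if req_gradA:
        # B was saved already dequantized (fp8B), so it is used directly.
        grad_A = torch.matmul(fp8_grad, B.t().to(fp8_grad.dtype)).to(ctx.dtype_A)
    if req_gradB:
        # A was saved in full precision and is re-quantized here ("re-quant").
        cA, stateA = F.quantize(A.t().contiguous().float(), code=ctx.fw_code)
        fp8At = F.dequantize(cA, stateA).to(fp8_grad.dtype)
        grad_B = torch.matmul(fp8At, fp8_grad).to(ctx.dtype_B)

    # The remaining forward arguments (out, fw_code, bw_code, bsz, bsz2) get no gradient.
    return grad_A, grad_B, None, None, None, None, None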



