bitsandbytes/triton/int8_matmul_mixed_dequantize.py [133:163]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        x_factor = tl.load(state_x_ptr + ram)[:, None]

        # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
        for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
            if EVEN_K:
                a = tl.load(A)
                b = tl.load(B)
            else:
                k_remaining = K - k * (BLOCK_K * SPLIT_K)
                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
            acc += tl.dot(a, b)
            A += BLOCK_K * SPLIT_K * stride_ak
            B += BLOCK_K * SPLIT_K * stride_bk

        acc = w_factor * (x_factor * (acc * divfactor))
        acc = acc.to(C.dtype.element_ty)

        # conditionally add bias
        if has_bias:
            bias = tl.load(bias + rn).to(C.dtype.element_ty)
            acc = acc + bias[None, :]

        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        if SPLIT_K == 1:
            tl.store(C, acc, mask=mask)
        else:
            tl.atomic_add(C, acc, mask=mask)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



bitsandbytes/triton/int8_matmul_rowwise_dequantize.py [133:162]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        x_factor = tl.load(state_x_ptr + ram)[:, None]

        # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
        for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
            if EVEN_K:
                a = tl.load(A)
                b = tl.load(B)
            else:
                k_remaining = K - k * (BLOCK_K * SPLIT_K)
                a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)
                b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)
            acc += tl.dot(a, b)
            A += BLOCK_K * SPLIT_K * stride_ak
            B += BLOCK_K * SPLIT_K * stride_bk

        acc = w_factor * (x_factor * (acc * divfactor))
        acc = acc.to(C.dtype.element_ty)

        if has_bias:
            bias = tl.load(bias + rn).to(C.dtype.element_ty)
            acc = acc + bias[None, :]

        C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        # handles write-back with reduction-splitting
        if SPLIT_K == 1:
            tl.store(C, acc, mask=mask)
        else:
            tl.atomic_add(C, acc, mask=mask)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



