def int8_matmul_rowwise_dequantize()

in bitsandbytes/triton/int8_matmul_rowwise_dequantize.py [0:0]


    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
        """Compute ``a @ b`` with the fused Triton kernel and return a float16 result.

        The numerical work is done by ``_int8_matmul_rowwise_dequantize``; this
        wrapper only normalizes strides, validates shapes, allocates the
        output and builds the launch grid.

        Args:
            a: Left operand of shape ``(M, K)``. Presumably int8 values
                quantized per row, with ``state_x`` holding the row scales
                (NOTE(review): confirm against the kernel).
            b: Right operand of shape ``(K, N)``. Presumably int8 values
                quantized per column, with ``state_w`` holding the column
                scales (NOTE(review): confirm against the kernel).
            state_x: Scale data for ``a``, forwarded unchanged to the kernel.
            state_w: Scale data for ``b``, forwarded unchanged to the kernel.
            bias: Optional tensor forwarded to the kernel, or ``None``; its
                presence is signalled to the kernel via a 0/1 ``has_bias`` flag.

        Returns:
            A newly allocated ``torch.float16`` tensor of shape ``(M, N)`` on
            ``a.device``, filled in by the kernel.

        Raises:
            AssertionError: if ``a.shape[1] != b.shape[0]``.
        """
        # 1 / (127 * 127): assumed to undo the symmetric int8 scaling of both
        # operands inside the kernel (NOTE(review): confirm against the kernel).
        int8_range = 127.0
        scale_divisor = 1.0 / (int8_range * int8_range)

        bias_flag = int(bias is not None)
        out_device = a.device

        def _maybe_contiguous(t):
            # Copy only when neither dimension is unit-stride; otherwise the
            # tensor is handed to the kernel as-is together with its strides.
            if t.stride(0) > 1 and t.stride(1) > 1:
                return t.contiguous()
            return t

        a = _maybe_contiguous(a)
        b = _maybe_contiguous(b)

        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        rows, inner = a.shape
        _, cols = b.shape

        out = torch.empty((rows, cols), device=out_device, dtype=torch.float16)

        def launch_grid(meta):
            # Axis 0: one entry per BLOCK_M x BLOCK_N output tile;
            # axis 1: SPLIT_K entries. BLOCK_M / BLOCK_N / SPLIT_K are the
            # kernel's launch meta-parameters (presumably from its autotune configs).
            tiles_m = triton.cdiv(rows, meta["BLOCK_M"])
            tiles_n = triton.cdiv(cols, meta["BLOCK_N"])
            return (tiles_m * tiles_n, meta["SPLIT_K"])

        stride_am, stride_ak = a.stride()
        stride_bk, stride_bn = b.stride()
        stride_cm, stride_cn = out.stride()

        # The accumulator type is always float32 here, whatever the input dtype.
        _int8_matmul_rowwise_dequantize[launch_grid](
            a,
            b,
            out,
            bias,
            state_x,
            state_w,
            rows,
            cols,
            inner,
            scale_divisor,
            bias_flag,
            stride_am,
            stride_ak,
            stride_bk,
            stride_bn,
            stride_cm,
            stride_cn,
            GROUP_M=8,
            ACC_TYPE=tl.float32,
        )
        return out