def qbytes_mm_impl_cuda()

in optimum/quanto/library/qbytes_mm.py [0:0]


def qbytes_mm_impl_cuda(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    """Pick between `qbytes_int_mm` and `qbytes_mm` based on dtypes and sizes, then call it.

    `qbytes_int_mm` is selected only when every one of these holds:

    * `activations` and `weights` both have dtype `torch.int8`;
    * the token count (product of all activation dims except the last) is greater than 16;
    * the token count, the input feature count (last activation dim) and the output
      feature count (first weight dim) are all multiples of 8.

    In every other case `qbytes_mm` is used.

    NOTE(review): the size thresholds presumably mirror constraints of the integer
    kernel behind `qbytes_int_mm` -- confirm against that helper.

    Args:
        activations: 2D or 3D tensor; its last dimension is read as the input feature count.
        weights: tensor whose first dimension is read as the output feature count.
        output_scales: forwarded unchanged to the selected kernel.

    Returns:
        Whatever the selected kernel returns.
    """
    assert activations.ndim in (2, 3)

    act_shape = activations.shape
    n_in = act_shape[-1]
    # Collapse the leading dimensions of a 3D input into a single token count.
    n_tokens = act_shape[0]
    if activations.ndim != 2:
        n_tokens = n_tokens * act_shape[1]
    n_out = weights.shape[0]

    # Short-circuit chain: the size checks are only evaluated for int8 operands.
    int_mm_eligible = (
        activations.dtype == torch.int8
        and weights.dtype == torch.int8
        and n_tokens > 16
        and n_tokens % 8 == 0
        and n_in % 8 == 0
        and n_out % 8 == 0
    )
    kernel = qbytes_int_mm if int_mm_eligible else qbytes_mm
    return kernel(activations, weights, output_scales)