in optimum/quanto/library/qbytes_mm.py [0:0]
def qbytes_mm_impl_cuda(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    """Choose between ``qbytes_int_mm`` and ``qbytes_mm`` for the given operands.

    ``qbytes_int_mm`` is used only when every one of these holds:

    * ``activations`` and ``weights`` both have dtype ``torch.int8``;
    * the row count (product of all ``activations`` dims except the last)
      is greater than 16;
    * the row count, the last ``activations`` dim and the first ``weights``
      dim are each a multiple of 8.

    Any other combination is forwarded to ``qbytes_mm``.

    Args:
        activations: 2-D or 3-D tensor; its last dimension is treated as the
            input-feature size.
        weights: tensor whose first dimension is treated as the
            output-feature size.
        output_scales: passed through unchanged to the selected kernel.

    Returns:
        Whatever the selected kernel returns.

    Raises:
        AssertionError: if ``activations`` is neither 2-D nor 3-D.
    """
    assert activations.ndim in (2, 3)

    # Gather every size up front, before any dispatch decision is made.
    n_in = activations.shape[-1]
    n_rows = activations.shape[0]
    if activations.ndim == 3:
        # Fold the two leading dimensions into a single row count.
        n_rows = activations.shape[0] * activations.shape[1]
    n_out = weights.shape[0]

    # Guard clauses: each unmet requirement falls back to the generic kernel.
    # NOTE(review): the thresholds presumably mirror constraints of the
    # integer matmul kernel behind qbytes_int_mm - confirm against it.
    if activations.dtype != torch.int8 or weights.dtype != torch.int8:
        return qbytes_mm(activations, weights, output_scales)
    if n_rows <= 16:
        return qbytes_mm(activations, weights, output_scales)
    if n_rows % 8 != 0 or n_in % 8 != 0 or n_out % 8 != 0:
        return qbytes_mm(activations, weights, output_scales)

    return qbytes_int_mm(activations, weights, output_scales)