in optimum/quanto/library/qbytes_mm.py [0:0]
# Module-level imports required by this function (shown here for completeness).
import torch
from packaging import version


def qbytes_mm_impl_mps(activations: torch.Tensor, weights: torch.Tensor, output_scales: torch.Tensor) -> torch.Tensor:
    in_features = activations.shape[-1]
    out_features = weights.shape[0]
    # Dispatch to the int8-packed matmul kernel only when it is supported:
    # torch >= 2.4, bfloat16 activations, int8 weights, and both feature
    # dimensions divisible by 32.
    if (
        version.parse(torch.__version__).release >= version.parse("2.4.0").release
        and activations.dtype == torch.bfloat16
        and weights.dtype == torch.int8
        and in_features % 32 == 0
        and out_features % 32 == 0
    ):
        # Quantized activation subclasses must be dequantized to plain
        # bfloat16 tensors before the packed kernel can consume them.
        if type(activations) is not torch.Tensor:
            activations = activations.dequantize()
        return qbytes_int8pack_mm(activations, weights, output_scales)
    # Fallback: generic dequantize-and-matmul path.
    return qbytes_mm(activations, weights, output_scales)
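
For context, a minimal usage sketch of this dispatcher is given below. It is a sketch under stated assumptions: the import path is inferred from the file location, the tensor shapes and the scale shape are illustrative, and the helper kernels qbytes_mm and qbytes_int8pack_mm are defined elsewhere in the same module and not reproduced here. An MPS device is assumed to be available.

# Hypothetical usage sketch; import path and shapes are assumptions, not part
# of the original source. Shapes are multiples of 32 so the int8pack branch
# (torch >= 2.4, bfloat16 activations, int8 weights) can be taken.
import torch

from optimum.quanto.library.qbytes_mm import qbytes_mm_impl_mps  # assumed import path

device = torch.device("mps")
activations = torch.randn(8, 64, dtype=torch.bfloat16, device=device)            # [batch, in_features]
weights = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device=device)    # [out_features, in_features]
output_scales = torch.ones(32, 1, dtype=torch.bfloat16, device=device)           # per-output-channel scales (shape assumed)

out = qbytes_mm_impl_mps(activations, weights, output_scales)
print(out.shape)  # expected: torch.Size([8, 32])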