optimum/quanto/library/extensions/cuda/marlin/marlin_cuda.cpp (56 lines of code) (raw):

/* * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "marlin_cuda.h" #include <torch/all.h> #include <torch/python.h> #include <ATen/cuda/CUDAContext.h> #include <cuda_runtime.h> #include "marlin_cuda_kernel.cuh" const int ERR_PROB_SHAPE = 1; const int ERR_KERN_SHAPE = 2; void mul( const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, const torch::Tensor& s, const torch::Tensor& sz, // ADDED: add scaled zero point torch::Tensor& workspace, int thread_k, int thread_n, int sms, int max_par ) { int prob_m = A.size(0); int prob_n = C.size(1); int prob_k = A.size(1); int groupsize = (s.size(0) == 1) ? -1 : prob_k / s.size(0); if (groupsize != -1 && groupsize * s.size(0) != prob_k) AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups."); if (workspace.numel() < prob_n / 128 * max_par) AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, "."); int dev = A.get_device(); int err = marlin_cuda( A.data_ptr(), B.data_ptr(), C.data_ptr(), s.data_ptr(), sz.data_ptr(), // ADDED: add scaled zero point prob_m, prob_n, prob_k, workspace.data_ptr(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par ); if (err == ERR_PROB_SHAPE) { AT_ERROR( "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "." ); } else if (err == ERR_KERN_SHAPE) { AT_ERROR( "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "." ); } }