BufferPtr CudaDevice::quantize()

in maga_transformer/cpp/devices/cuda_impl/CudaQuantizeOp.cc [33:202]


BufferPtr CudaDevice::quantize(const QuantizeParams& params) {
    RTP_LLM_CHECK_WITH_INFO((params.input.type() == DataType::TYPE_FP16 ||
                        params.input.type() == DataType::TYPE_FP32 ||
                        params.input.type() == DataType::TYPE_BF16),
        "cuda quantize only supports fp16, fp32 or bf16 input, but got %d.", params.input.type());

    RTP_LLM_CHECK_WITH_INFO((params.input.dim() == 2),
        "cuda quantize only supports 2D input.");

    RTP_LLM_CHECK_WITH_INFO((params.axis == (params.input.dim() - 1)),
        "cuda quantize only supports quantizing along the last axis, got axis %d of %d.", params.axis, params.input.dim());

    BufferPtr scales = nullptr;
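    // fp8 schemes produce e4m3 kernel data; all other schemes produce int8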
    DataType out_data_type = (params.qscheme == QScheme::Qfp8PerTensor ||
                              params.qscheme == QScheme::Qfp8PerTokenBlock)
                                 ? DataType::TYPE_FP8_E4M3
                                 : DataType::TYPE_INT8;

    vector<size_t> input_shape = params.input.shape();
    if (params.qscheme == QScheme::Qfp8PerTokenBlock && params.paddingSize) {
        // round the token count up to a multiple of paddingSize
        input_shape[0] = (input_shape[0] + params.paddingSize - 1) / params.paddingSize * params.paddingSize;
    }

    auto kernel = allocateBuffer({out_data_type,
                                 input_shape,
                                 getMemAllocationType(params.input.where())},
                                 {"kernel"});

    if (params.qscheme == QScheme::Qint8WeightOnly) {
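        // weight-only int8: CPU-side symmetric quantization through TensorRT-LLM,
        // producing one fp16 scale per output channel (last dim)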
        RTP_LLM_CHECK_WITH_INFO(params.input.where() == MemoryType::MEMORY_CPU,
            "Qint8WeightOnly quantize only supports CPU input");
        size_t axis = params.input.dim() - 1;
        scales = allocateBuffer({DataType::TYPE_FP16,
                                {input_shape[axis]},
                                getMemAllocationType(params.input.where())},
                                {"scales"});
        // TODO(lidongjin) The dispatch macro only supports multiple template types without data cast,
        // or a single template type with data cast; here we need multiple template types with data cast.
        if (params.input.type() == DataType::TYPE_FP16) {
            trt::symmetric_quantize(kernel->data<int8_t>(),
                                    nullptr,
                                    scales->data<half>(),
                                    params.input.data<half>(),
                                    input_shape,
                                    trtQuantTypeConvert(params.qtype),
                                    get_sm());
        } else if (params.input.type() == DataType::TYPE_BF16) {
            trt::symmetric_quantize(kernel->data<int8_t>(),
                                    nullptr,
                                    scales->data<half>(),
                                    params.input.data<__nv_bfloat16>(),
                                    input_shape,
                                    trtQuantTypeConvert(params.qtype),
                                    get_sm());
        } else if (params.input.type() == DataType::TYPE_FP32) {
            trt::symmetric_quantize(kernel->data<int8_t>(),
                                    nullptr,
                                    scales->data<half>(),
                                    params.input.data<float>(),
                                    input_shape,
                                    trtQuantTypeConvert(params.qtype),
                                    get_sm());
        } else {
            RTP_LLM_CHECK_WITH_INFO(false,
                "unsupported data type [%d] for cuda quantize input.", params.input.type());
        }
    } else if (params.qscheme == QScheme::Qint8PerToken) {
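        // per-token int8: one fp32 scale per row; presumably
        // scale[i] = max_j |x[i][j]| / 127 and q[i][j] = round(x[i][j] / scale[i]),
        // with the optional per-channel smoother/shift applied by the kernel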
        scales = allocateBuffer({DataType::TYPE_FP32,
                                {input_shape[0]},
                                getMemAllocationType(params.input.where())},
                                {"scales"});
        DISPATCH_CUDA_FUNCTION_DATA_TYPE(params.input.type(), invokePerTokenQuantization,
                                         kernel->data<int8_t>(),
                                         params.input.data(),
                                         input_shape[0],
                                         input_shape[1],
                                         scales->data<float>(),
                                         params.smoother.has_value() ? params.smoother.value().get().data<float>() : nullptr,
                                         params.shift.has_value() ? params.shift.value().get().data<float>() : nullptr,
                                         stream_);

    } else if (params.qscheme == QScheme::Qint8PerTensor) {
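        // per-tensor int8: a single static scale; the returned scales buffer
        // aliases static_scale_reciprocal rather than copying it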
        RTP_LLM_CHECK_WITH_INFO(params.static_scale_reciprocal.has_value(),
            "static_scale_reciprocal should not be nullptr in Qint8PerTensor");
        scales = BufferPtr(new Buffer(params.static_scale_reciprocal.value().get().where(),
                                      params.static_scale_reciprocal.value().get().type(),
                                      params.static_scale_reciprocal.value().get().shape(),
                                      params.static_scale_reciprocal.value().get().data()));

        DISPATCH_CUDA_FUNCTION_DATA_TYPE(params.input.type(), invokeQuantization,
                                         kernel->data<int8_t>(),
                                         params.input.data(),
                                         params.input.size(),
                                         params.static_scale.value().get().data<float>(),
                                         stream_,
                                         -1);
#ifdef ENABLE_FP8
    } else if (params.qscheme == QScheme::Qfp8PerTensor) {
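        // per-tensor fp8 (e4m3): same static-scale handling as the int8 path above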
        RTP_LLM_CHECK_WITH_INFO(params.static_scale_reciprocal.has_value(),
            "static_scale_reciprocal should not be nullptr in Qfp8PerTensor");
        scales = BufferPtr(new Buffer(params.static_scale_reciprocal.value().get().where(),
                                      params.static_scale_reciprocal.value().get().type(),
                                      params.static_scale_reciprocal.value().get().shape(),
                                      params.static_scale_reciprocal.value().get().data()));
        switch (params.input.type()) {
            case DataType::TYPE_FP32:
                trt_common::invokeQuantizeMatrix(kernel->data<__nv_fp8_e4m3>(),
                            params.static_scale.value().get().data<float>(),
                            params.input.data<float>(),
                            params.input.size(),
                            input_shape[0],
                            trt_common::QuantizeMode::PER_TENSOR,
                            stream_);
                break;
            case DataType::TYPE_FP16:
                trt_common::invokeQuantizeMatrix(kernel->data<__nv_fp8_e4m3>(),
                            params.static_scale.value().get().data<float>(),
                            params.input.data<half>(),
                            params.input.size(),
                            input_shape[0],
                            trt_common::QuantizeMode::PER_TENSOR,
                            stream_);
                break;
#ifdef ENABLE_BF16
            case DataType::TYPE_BF16:
                trt_common::invokeQuantizeMatrix(kernel->data<__nv_fp8_e4m3>(),
                        params.static_scale.value().get().data<float>(),
                        params.input.data<__nv_bfloat16>(),
                        params.input.size(),
                        input_shape[0],
                        trt_common::QuantizeMode::PER_TENSOR,
                        stream_);
                break;
#endif
            default:
                RTP_LLM_CHECK_WITH_INFO(false, "unsupport data type");
        }
#ifdef ENABLE_BF16
    } else if (params.qscheme == QScheme::Qfp8PerTokenBlock) {
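        // per-token-block fp8 (e4m3): one fp32 scale per 1x128 block along the
        // last dim; with paddingSize set, the scales buffer is laid out
        // transposed as [hidden/128, padded_tokens] instead of [tokens, hidden/128]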
        RTP_LLM_CHECK_WITH_INFO(input_shape[1] % 128 == 0, "last dim must be divisible by 128");
        auto scales_shape = params.paddingSize
                                ? vector<size_t>({input_shape[1] / 128, input_shape[0]})
                                : vector<size_t>({input_shape[0], input_shape[1] / 128});
        scales = allocateBuffer({DataType::TYPE_FP32,
                                scales_shape,
                                getMemAllocationType(params.input.where())},
                                {"scales"});
        if (input_shape[0] == 0) {
            return BufferPtr(new QBuffer(std::move(kernel),
                                         std::move(scales),
                                         BufferPtr(new Buffer(params.input.where(),
                                                              DataType::TYPE_INVALID,
                                                              {0},
                                                              nullptr))));
        }
        tensorrt_llm::common::invokeComputeFP8Quantize128(kernel->data<__nv_fp8_e4m3>(),
                                                          scales->data<float>(),
                                                          params.input.data<__nv_bfloat16>(),
                                                          input_shape[0],
                                                          input_shape[1],
                                                          params.input.size(),
                                                          params.paddingSize,
                                                          stream_);
#endif
#endif
    } else {
        RTP_LLM_CHECK_WITH_INFO(false, "params qscheme type unknown: %d", int(params.qscheme));
    }

    sync_check_cuda_error();

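    // the zeros buffer is an empty placeholder: this op only implements symmetric quantization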
    auto zeros_type = scales->where();
    return BufferPtr(new QBuffer(std::move(kernel),
                                 std::move(scales),
                                 BufferPtr(new Buffer(zeros_type,
                                                      DataType::TYPE_INVALID,
                                                      {0},
                                                      nullptr))));
}
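
For orientation, a minimal caller-side sketch follows. The QuantizeParams constructor arguments and the result layout described below are assumptions inferred from how this function reads and builds them, not verified signatures.

// Hypothetical usage: per-token int8 quantization of a [tokens, hidden]
// bf16 activation buffer. The QuantizeParams field order is an assumption.
BufferPtr input = device->allocateBuffer(
    {DataType::TYPE_BF16, {tokens, hidden}, AllocationType::DEVICE});

QuantizeParams qparams(*input,                  // 2D input, fp16/fp32/bf16 only
                       QScheme::Qint8PerToken,  // one fp32 scale per row
                       /* axis = */ 1);         // must be the last axis
BufferPtr quantized = device->quantize(qparams);

// quantized is a QBuffer holding int8 kernel data with the input's shape,
// a [tokens] fp32 scales buffer, and an empty zeros placeholder.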