mlir/lib/CAPI/Dialect/Quant.cpp (158 lines of code) (raw):

//===- LLVM.cpp - C Interface for Quant dialect ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Quant.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Quant/QuantTypes.h" using namespace mlir; MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) //===---------------------------------------------------------------------===// // QuantizedType //===---------------------------------------------------------------------===// bool mlirTypeIsAQuantizedType(MlirType type) { return unwrap(type).isa<quant::QuantizedType>(); } unsigned mlirQuantizedTypeGetSignedFlag() { return quant::QuantizationFlags::Signed; } int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned, unsigned integralWidth) { return quant::QuantizedType::getDefaultMinimumForInteger(isSigned, integralWidth); } int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, unsigned integralWidth) { return quant::QuantizedType::getDefaultMaximumForInteger(isSigned, integralWidth); } MlirType mlirQuantizedTypeGetExpressedType(MlirType type) { return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType()); } unsigned mlirQuantizedTypeGetFlags(MlirType type) { return unwrap(type).cast<quant::QuantizedType>().getFlags(); } bool mlirQuantizedTypeIsSigned(MlirType type) { return unwrap(type).cast<quant::QuantizedType>().isSigned(); } MlirType mlirQuantizedTypeGetStorageType(MlirType type) { return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType()); } int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) { return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin(); } int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) { return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax(); } unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) { return unwrap(type) .cast<quant::QuantizedType>() .getStorageTypeIntegralWidth(); } bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, MlirType candidate) { return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType( unwrap(candidate)); } MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type))); } MlirType mlirQuantizedTypeCastFromStorageType(MlirType type, MlirType candidate) { return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType( unwrap(candidate))); } MlirType mlirQuantizedTypeCastToStorageType(MlirType type) { return wrap(quant::QuantizedType::castToStorageType( unwrap(type).cast<quant::QuantizedType>())); } MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type, MlirType candidate) { return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType( unwrap(candidate))); } MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { return wrap(quant::QuantizedType::castToExpressedType(unwrap(type))); } MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate) { return wrap( unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType( unwrap(candidate))); } //===---------------------------------------------------------------------===// // AnyQuantizedType //===---------------------------------------------------------------------===// bool mlirTypeIsAAnyQuantizedType(MlirType type) { return unwrap(type).isa<quant::AnyQuantizedType>(); } MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, MlirType expressedType, int64_t storageTypeMin, int64_t storageTypeMax) { return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType), unwrap(expressedType), storageTypeMin, storageTypeMax)); } //===---------------------------------------------------------------------===// // UniformQuantizedType //===---------------------------------------------------------------------===// bool mlirTypeIsAUniformQuantizedType(MlirType type) { return unwrap(type).isa<quant::UniformQuantizedType>(); } MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, MlirType expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) { return wrap(quant::UniformQuantizedType::get( flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint, storageTypeMin, storageTypeMax)); } double mlirUniformQuantizedTypeGetScale(MlirType type) { return unwrap(type).cast<quant::UniformQuantizedType>().getScale(); } int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) { return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint(); } bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint(); } //===---------------------------------------------------------------------===// // UniformQuantizedPerAxisType //===---------------------------------------------------------------------===// bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { return unwrap(type).isa<quant::UniformQuantizedPerAxisType>(); } MlirType mlirUniformQuantizedPerAxisTypeGet( unsigned flags, MlirType storageType, MlirType expressedType, intptr_t nDims, double *scales, int64_t *zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax) { return wrap(quant::UniformQuantizedPerAxisType::get( flags, unwrap(storageType), unwrap(expressedType), llvm::makeArrayRef(scales, nDims), llvm::makeArrayRef(zeroPoints, nDims), quantizedDimension, storageTypeMin, storageTypeMax)); } intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) { return unwrap(type) .cast<quant::UniformQuantizedPerAxisType>() .getScales() .size(); } double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) { return unwrap(type) .cast<quant::UniformQuantizedPerAxisType>() .getScales()[pos]; } int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, intptr_t pos) { return unwrap(type) .cast<quant::UniformQuantizedPerAxisType>() .getZeroPoints()[pos]; } int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) { return unwrap(type) .cast<quant::UniformQuantizedPerAxisType>() .getQuantizedDimension(); } bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint(); } //===---------------------------------------------------------------------===// // CalibratedQuantizedType //===---------------------------------------------------------------------===// bool mlirTypeIsACalibratedQuantizedType(MlirType type) { return unwrap(type).isa<quant::CalibratedQuantizedType>(); } MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, double max) { return wrap( quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max)); } double mlirCalibratedQuantizedTypeGetMin(MlirType type) { return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin(); } double mlirCalibratedQuantizedTypeGetMax(MlirType type) { return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax(); }