mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (142 lines of code) (raw):

//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "CodegenUtils.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" using namespace mlir; using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// // ExecutionEngine/SparseTensorUtils helper functions. //===----------------------------------------------------------------------===// OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { switch (width) { case 64: return OverheadType::kU64; case 32: return OverheadType::kU32; case 16: return OverheadType::kU16; case 8: return OverheadType::kU8; case 0: return OverheadType::kIndex; } llvm_unreachable("Unsupported overhead bitwidth"); } OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { if (tp.isIndex()) return OverheadType::kIndex; if (auto intTp = tp.dyn_cast<IntegerType>()) return overheadTypeEncoding(intTp.getWidth()); llvm_unreachable("Unknown overhead type"); } Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { switch (ot) { case OverheadType::kIndex: return builder.getIndexType(); case OverheadType::kU64: return builder.getIntegerType(64); case OverheadType::kU32: return builder.getIntegerType(32); case OverheadType::kU16: return builder.getIntegerType(16); case OverheadType::kU8: return builder.getIntegerType(8); } llvm_unreachable("Unknown OverheadType"); } Type mlir::sparse_tensor::getPointerOverheadType( Builder &builder, const SparseTensorEncodingAttr &enc) { return getOverheadType(builder, overheadTypeEncoding(enc.getPointerBitWidth())); } Type mlir::sparse_tensor::getIndexOverheadType( Builder &builder, const SparseTensorEncodingAttr &enc) { return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth())); } StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { switch (ot) { case OverheadType::kIndex: return ""; case OverheadType::kU64: return "64"; case OverheadType::kU32: return "32"; case OverheadType::kU16: return "16"; case OverheadType::kU8: return "8"; } llvm_unreachable("Unknown OverheadType"); } StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) { return overheadTypeFunctionSuffix(overheadTypeEncoding(tp)); } PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { if (elemTp.isF64()) return PrimaryType::kF64; if (elemTp.isF32()) return PrimaryType::kF32; if (elemTp.isInteger(64)) return PrimaryType::kI64; if (elemTp.isInteger(32)) return PrimaryType::kI32; if (elemTp.isInteger(16)) return PrimaryType::kI16; if (elemTp.isInteger(8)) return PrimaryType::kI8; llvm_unreachable("Unknown primary type"); } StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) { switch (pt) { case PrimaryType::kF64: return "F64"; case PrimaryType::kF32: return "F32"; case PrimaryType::kI64: return "I64"; case PrimaryType::kI32: return "I32"; case PrimaryType::kI16: return "I16"; case PrimaryType::kI8: return "I8"; } llvm_unreachable("Unknown PrimaryType"); } StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) { return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp)); } DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( SparseTensorEncodingAttr::DimLevelType dlt) { switch (dlt) { case SparseTensorEncodingAttr::DimLevelType::Dense: return DimLevelType::kDense; case SparseTensorEncodingAttr::DimLevelType::Compressed: return DimLevelType::kCompressed; case SparseTensorEncodingAttr::DimLevelType::Singleton: return DimLevelType::kSingleton; } llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); } //===----------------------------------------------------------------------===// // Misc code generators. //===----------------------------------------------------------------------===// mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { if (tp.isa<FloatType>()) return builder.getFloatAttr(tp, 1.0); if (tp.isa<IndexType>()) return builder.getIndexAttr(1); if (auto intTp = tp.dyn_cast<IntegerType>()) return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); if (tp.isa<RankedTensorType, VectorType>()) { auto shapedTp = tp.cast<ShapedType>(); if (auto one = getOneAttr(builder, shapedTp.getElementType())) return DenseElementsAttr::get(shapedTp, one); } llvm_unreachable("Unsupported attribute type"); } Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, Value v) { Type tp = v.getType(); Value zero = constantZero(builder, loc, tp); if (tp.isa<FloatType>()) return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, zero); if (tp.isIntOrIndex()) return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, zero); llvm_unreachable("Non-numeric type"); }