in lib/bef_converter/mlir_to_bef/bef_attr_emitter.cc [136:267]
// Classifies `attr` into the BEFAttributeType used to encode it in BEF.
//
// Scalar integer, float and string attributes map to their DType value (cast
// to BEFAttributeType). Type, corert.shape, dense elements, array and symbol
// reference attributes map to their dedicated BEFAttributeType values.
// Anything else -- including an array containing an unsupported element --
// yields BEFAttributeType::kUnsupported.
BEFAttributeType BefAttrEmitter::GetBefAttributeType(mlir::Attribute attr) {
  // We support signed/signless integers of width 1 (stored as 1 byte in BEF),
  // 8, 16, 32 and 64, and unsigned integers of width 8, 16, 32 and 64.
  if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
    // An IntegerAttr may also carry a non-IntegerType (e.g. `index`); use
    // dyn_cast so such attributes fall through to kUnsupported instead of
    // failing the cast.
    if (auto int_type = int_attr.getType().dyn_cast<mlir::IntegerType>()) {
      if (int_type.isUnsigned()) {
        switch (int_type.getWidth()) {
          case 8:
            return static_cast<BEFAttributeType>(DType::UI8);
          case 16:
            return static_cast<BEFAttributeType>(DType::UI16);
          case 32:
            return static_cast<BEFAttributeType>(DType::UI32);
          case 64:
            return static_cast<BEFAttributeType>(DType::UI64);
        }
      } else {
        switch (int_type.getWidth()) {
          case 1:
            return static_cast<BEFAttributeType>(DType::I1);
          case 8:
            return static_cast<BEFAttributeType>(DType::I8);
          case 16:
            return static_cast<BEFAttributeType>(DType::I16);
          case 32:
            return static_cast<BEFAttributeType>(DType::I32);
          case 64:
            return static_cast<BEFAttributeType>(DType::I64);
        }
      }
    }
  }

  // We support BF16, F16, F32 and F64 floats.
  if (auto float_attr = attr.dyn_cast<mlir::FloatAttr>()) {
    auto float_type = float_attr.getType();
    if (float_type.isBF16()) return static_cast<BEFAttributeType>(DType::BF16);
    if (float_type.isF16()) return static_cast<BEFAttributeType>(DType::F16);
    if (float_type.isF32()) return static_cast<BEFAttributeType>(DType::F32);
    if (float_type.isF64()) return static_cast<BEFAttributeType>(DType::F64);
  }

  // We support string attributes.
  if (attr.isa<mlir::StringAttr>())
    return static_cast<BEFAttributeType>(DType::String);

  // We support i1, i8, i16, i32, i64, ui8, ui16, ui32, ui64, bf16, f16, f32,
  // f64, quint8, quint16, qint8, qint16, qint32, complex64, complex128,
  // string, resource and variant type attributes. (`isInteger(width)` matches
  // integers of any signedness, which covers the unsigned variants.)
  if (auto type_attr = attr.dyn_cast<mlir::TypeAttr>()) {
    auto type = type_attr.getValue();
    if (type.isInteger(1) || type.isInteger(8) || type.isInteger(16) ||
        type.isInteger(32) || type.isInteger(64) || type.isBF16() ||
        type.isF16() || type.isF32() || type.isF64() ||
        type.isa<corert::StringType>() || type.isa<corert::ResourceType>() ||
        type.isa<corert::VariantType>() || type.isa<corert::Quint8Type>() ||
        type.isa<corert::Quint16Type>() || type.isa<corert::Qint8Type>() ||
        type.isa<corert::Qint16Type>() || type.isa<corert::Qint32Type>())
      return BEFAttributeType::kType;

    if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
      auto element_type = complex_type.getElementType();
      if (element_type.isF32() || element_type.isF64())
        return BEFAttributeType::kType;
    }
  }

  // We support corert.shape attributes.
  if (attr.isa<tfrt::corert::ShapeAttr>()) {
    return BEFAttributeType::kShape;
  }

  // We support dense attributes.
  if (auto dense_elements_attr = attr.dyn_cast<mlir::DenseElementsAttr>()) {
    auto element_type =
        ConvertMlirTypeToDType(dense_elements_attr.getType().getElementType());
    // Only the element dtypes listed below are supported. Notably, string
    // dtype is excluded because strings have variable size.
    //
    // TODO(tfrt-devs): Consider supporting string elements in the dense
    // attribute.
    if (element_type == DType::UI8 || element_type == DType::UI16 ||
        element_type == DType::UI32 || element_type == DType::UI64 ||
        element_type == DType::I1 || element_type == DType::I8 ||
        element_type == DType::I16 || element_type == DType::I32 ||
        element_type == DType::I64 || element_type == DType::BF16 ||
        element_type == DType::F16 || element_type == DType::F32 ||
        element_type == DType::F64 || element_type == DType::Complex64 ||
        element_type == DType::Complex128)
      return BEFAttributeType::kDense;

    return BEFAttributeType::kUnsupported;
  }

  // We support arrays of supported attribute values.
  if (auto array_attr = attr.dyn_cast<mlir::ArrayAttr>()) {
    if (array_attr.empty()) {
      return BEFAttributeType::kEmptyArray;
    }

    auto elements = array_attr.getValue();

    // Classify the first element once; it is the reference type for the rest.
    auto first_attr_type = GetBefAttributeType(elements.front());
    if (first_attr_type == BEFAttributeType::kUnsupported)
      return BEFAttributeType::kUnsupported;

    // Only fixed attributes can be included in an array, and an array
    // requires all elements to have the same type (and therefore size).
    bool is_array = IsFixedAttribute(first_attr_type);

    for (mlir::Attribute elt : elements.drop_front()) {
      auto attr_type = GetBefAttributeType(elt);
      // Keep scanning after a type mismatch so an unsupported element is
      // reported wherever it appears in the array, not only before the first
      // mismatch.
      if (attr_type == BEFAttributeType::kUnsupported)
        return BEFAttributeType::kUnsupported;
      if (attr_type != first_attr_type) is_array = false;
    }

    if (is_array) return GetArrayAttributeType(first_attr_type);

    return BEFAttributeType::kAggregate;
  }

  // We support symbol references to compiled functions.
  if (attr.isa<mlir::SymbolRefAttr>()) {
    return BEFAttributeType::kSymbolRef;
  }

  return BEFAttributeType::kUnsupported;
}