BEFAttributeType BefAttrEmitter::GetBefAttributeType()

in lib/bef_converter/mlir_to_bef/bef_attr_emitter.cc [136:267]


BEFAttributeType BefAttrEmitter::GetBefAttributeType(mlir::Attribute attr) {
  // Classifies `attr` into the BEF attribute kind used to emit it. Returns
  // BEFAttributeType::kUnsupported for anything that cannot be encoded.

  // Integers: i1, i8, i16, i32, i64 and unsigned ui8, ui16, ui32, ui64
  // (i1 is stored as 1 byte in BEF). Any other width falls through and is
  // eventually reported as unsupported.
  if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
    // An IntegerAttr may also carry `index` type, which is not an
    // IntegerType. Use dyn_cast so such attributes are reported as
    // unsupported instead of tripping the cast assertion.
    if (auto int_type = int_attr.getType().dyn_cast<mlir::IntegerType>()) {
      if (int_type.isUnsigned()) {
        switch (int_type.getWidth()) {
          case 8:
            return static_cast<BEFAttributeType>(DType::UI8);
          case 16:
            return static_cast<BEFAttributeType>(DType::UI16);
          case 32:
            return static_cast<BEFAttributeType>(DType::UI32);
          case 64:
            return static_cast<BEFAttributeType>(DType::UI64);
        }
      } else {
        switch (int_type.getWidth()) {
          case 1:
            return static_cast<BEFAttributeType>(DType::I1);
          case 8:
            return static_cast<BEFAttributeType>(DType::I8);
          case 16:
            return static_cast<BEFAttributeType>(DType::I16);
          case 32:
            return static_cast<BEFAttributeType>(DType::I32);
          case 64:
            return static_cast<BEFAttributeType>(DType::I64);
        }
      }
    }
  }

  // We support BF16, F16, F32 and F64 floats.
  if (auto float_attr = attr.dyn_cast<mlir::FloatAttr>()) {
    auto float_type = float_attr.getType();
    if (float_type.isBF16())
      return static_cast<BEFAttributeType>(DType::BF16);
    if (float_type.isF16()) return static_cast<BEFAttributeType>(DType::F16);
    if (float_type.isF32()) return static_cast<BEFAttributeType>(DType::F32);
    if (float_type.isF64()) return static_cast<BEFAttributeType>(DType::F64);
  }

  // We support string attributes.
  if (attr.isa<mlir::StringAttr>())
    return static_cast<BEFAttributeType>(DType::String);

  // We support i1, i8, i16, i32, i64, ui8, ui16, ui32, ui64, bf16, f16, f32,
  // f64, quint8, quint16, qint8, qint16, qint32, complex64, complex128,
  // string, resource and variant type attributes.
  if (auto type_attr = attr.dyn_cast<mlir::TypeAttr>()) {
    auto type = type_attr.getValue();
    if (type.isInteger(1) || type.isInteger(8) || type.isInteger(16) ||
        type.isInteger(32) || type.isInteger(64) || type.isBF16() ||
        type.isF16() || type.isF32() || type.isF64() ||
        type.isa<corert::StringType>() || type.isa<corert::ResourceType>() ||
        type.isa<corert::VariantType>() || type.isa<corert::Quint8Type>() ||
        type.isa<corert::Quint16Type>() || type.isa<corert::Qint8Type>() ||
        type.isa<corert::Qint16Type>() || type.isa<corert::Qint32Type>())
      return BEFAttributeType::kType;

    // complex64 / complex128.
    if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
      auto element_type = complex_type.getElementType();
      if (element_type.isF32() || element_type.isF64())
        return BEFAttributeType::kType;
    }
  }

  // We support corert.shape attributes.
  if (attr.isa<tfrt::corert::ShapeAttr>()) {
    return BEFAttributeType::kShape;
  }

  // We support dense attributes.
  if (auto dense_elements_attr = attr.dyn_cast<mlir::DenseElementsAttr>()) {
    auto element_type =
        ConvertMlirTypeToDType(dense_elements_attr.getType().getElementType());
    // We only support dense attributes with dtype element type. The exception
    // is that we don't support string dtype, because strings have variable
    // size.
    //
    // TODO(tfrt-devs): Consider supporting string elements in the dense
    // attribute.
    if (element_type == DType::UI8 || element_type == DType::UI16 ||
        element_type == DType::UI32 || element_type == DType::UI64 ||
        element_type == DType::I1 || element_type == DType::I8 ||
        element_type == DType::I16 || element_type == DType::I32 ||
        element_type == DType::I64 || element_type == DType::BF16 ||
        element_type == DType::F16 || element_type == DType::F32 ||
        element_type == DType::F64 || element_type == DType::Complex64 ||
        element_type == DType::Complex128)
      return BEFAttributeType::kDense;

    return BEFAttributeType::kUnsupported;
  }

  // We support arrays of supported attribute values. An array whose elements
  // all share one fixed attribute type is a typed array; any other array of
  // supported elements is an aggregate. A single unsupported element makes
  // the whole array unsupported.
  if (auto array_attr = attr.dyn_cast<mlir::ArrayAttr>()) {
    if (array_attr.empty()) {
      return BEFAttributeType::kEmptyArray;
    }

    auto elements = array_attr.getValue();

    // Classify the first element once; it is not revisited in the loop below,
    // which avoids recomputing it at every level of nested arrays.
    auto first_attr_type = GetBefAttributeType(elements.front());
    if (first_attr_type == BEFAttributeType::kUnsupported)
      return BEFAttributeType::kUnsupported;

    // Only fixed attributes can be included in an array.
    bool is_array = IsFixedAttribute(first_attr_type);

    for (auto elt : elements.drop_front()) {
      auto attr_type = GetBefAttributeType(elt);
      if (attr_type == BEFAttributeType::kUnsupported)
        return BEFAttributeType::kUnsupported;

      // Arrays require all elements to have the same type (and thus size).
      // Do not stop scanning on a mismatch: later elements must still be
      // checked for being unsupported.
      if (attr_type != first_attr_type) is_array = false;
    }

    if (is_array) return GetArrayAttributeType(first_attr_type);

    return BEFAttributeType::kAggregate;
  }

  // We support symbol references to compiled functions.
  if (attr.isa<mlir::SymbolRefAttr>()) {
    return BEFAttributeType::kSymbolRef;
  }

  return BEFAttributeType::kUnsupported;
}