size_t BefAttrEmitter::EmitAttribute()

in lib/bef_converter/mlir_to_bef/bef_attr_emitter.cc [287:366]


// Encodes `mlir_attr` in the form selected by `attribute_type` and returns
// whatever the encoder that handled it returns (presumably the position of the
// emitted attribute -- confirm against the Encode*/Emit* helpers).
//
// The checks run in a fixed order: integer dtypes, floating-point dtypes,
// strings, type and shape attributes, then array, dense and aggregate
// attributes. Each branch converts `mlir_attr` with `cast<>` to the MLIR
// attribute class it expects, so the caller must pass an attribute of the
// matching class. Falling through every check is reported through
// llvm_unreachable.
size_t BefAttrEmitter::EmitAttribute(BEFAttributeType attribute_type,
                                     mlir::Attribute mlir_attr) {
  // --- Integer dtypes -------------------------------------------------------
  // I1 goes through the same uint8_t path as UI8.
  const bool uses_uint8_encoding =
      IsMatchedWithDType(attribute_type, DType::UI8) ||
      IsMatchedWithDType(attribute_type, DType::I1);
  if (uses_uint8_encoding) {
    return EmitIntegerAttribute<uint8_t>(mlir_attr);
  }
  if (IsMatchedWithDType(attribute_type, DType::UI16)) {
    return EmitIntegerAttribute<uint16_t>(mlir_attr);
  }
  if (IsMatchedWithDType(attribute_type, DType::UI32)) {
    return EmitIntegerAttribute<uint32_t>(mlir_attr);
  }
  if (IsMatchedWithDType(attribute_type, DType::UI64)) {
    return EmitIntegerAttribute<uint64_t>(mlir_attr);
  }
  if (IsMatchedWithDType(attribute_type, DType::I8)) {
    return EmitIntegerAttribute<int8_t>(mlir_attr);
  }
  if (IsMatchedWithDType(attribute_type, DType::I16)) {
    return EmitIntegerAttribute<int16_t>(mlir_attr);
  }
  if (IsMatchedWithDType(attribute_type, DType::I32)) {
    return EmitIntegerAttribute<int32_t>(mlir_attr);
  }
  if (IsMatchedWithDType(attribute_type, DType::I64)) {
    return EmitIntegerAttribute<int64_t>(mlir_attr);
  }

  // --- Floating-point dtypes ------------------------------------------------
  if (IsMatchedWithDType(attribute_type, DType::F32)) {
    auto float_attr = mlir_attr.cast<mlir::FloatAttr>();
    const float value =
        static_cast<float>(float_attr.getValue().convertToFloat());
    return EncodeAttr<float>(value);
  }
  if (IsMatchedWithDType(attribute_type, DType::F64)) {
    auto float_attr = mlir_attr.cast<mlir::FloatAttr>();
    const double value =
        static_cast<double>(float_attr.getValue().convertToDouble());
    return EncodeAttr<double>(value);
  }

  // Half-precision values are not converted to a native float type; their bit
  // pattern is taken as an integer and emitted as a uint16_t.
  const bool is_half_precision =
      IsMatchedWithDType(attribute_type, DType::F16) ||
      IsMatchedWithDType(attribute_type, DType::BF16);
  if (is_half_precision) {
    auto float_attr = mlir_attr.cast<mlir::FloatAttr>();
    const uint16_t raw_bits = static_cast<uint16_t>(
        float_attr.getValue().bitcastToAPInt().getLimitedValue());
    return EncodeAttr<uint16_t>(raw_bits);
  }

  // --- Strings, types and shapes --------------------------------------------
  if (IsMatchedWithDType(attribute_type, DType::String)) {
    return EncodeStringAttr(mlir_attr.cast<mlir::StringAttr>().getValue());
  }

  if (attribute_type == BEFAttributeType::kType) {
    // The MLIR type is mapped to a dtype, which is then emitted as one byte.
    auto type_attr = mlir_attr.cast<mlir::TypeAttr>();
    const auto converted_dtype = ConvertMlirTypeToDType(type_attr.getValue());
    return EncodeAttr<uint8_t>(static_cast<uint8_t>(converted_dtype));
  }

  if (attribute_type == BEFAttributeType::kShape) {
    auto shape = mlir_attr.cast<tfrt::corert::ShapeAttr>();
    // Shapes without a rank have their own encoding that takes no dimensions.
    if (!shape.hasRank()) {
      return EncodeUnrankedShapeAttr();
    }
    return EncodeRankedShapeAttr(shape.getShape());
  }

  // --- Composite attributes -------------------------------------------------
  if (IsArrayAttribute(attribute_type)) {
    auto array = mlir_attr.cast<mlir::ArrayAttr>();
    return EmitArrayAttribute(attribute_type, array);
  }

  if (IsDenseAttribute(attribute_type)) {
    auto dense = mlir_attr.cast<mlir::DenseElementsAttr>();
    return EmitDenseAttribute(attribute_type, dense);
  }

  if (attribute_type == BEFAttributeType::kAggregate) {
    auto elements = mlir_attr.cast<mlir::ArrayAttr>();
    return EmitAggregatedAttribute(elements);
  }

  // No check above matched `attribute_type`.
  llvm_unreachable("Unknown attribute");
}