mlir::Attribute BefAttrReader::ReadAttribute()

in lib/bef_converter/bef_to_mlir/bef_attr_reader.cc [87:199]


// Converts the BEF attribute of kind `attribute_type` found at `offset` inside
// `attributes_` into an MLIR attribute.
//
// Array, dense and aggregate attributes are handed to their dedicated readers.
// All remaining kinds (integers, floats, strings, type and shape attributes)
// are decoded here directly from the bytes at `offset`. An `attribute_type`
// that matches none of the handled kinds reaches llvm_unreachable.
mlir::Attribute BefAttrReader::ReadAttribute(BEFAttributeType attribute_type,
                                             size_t offset) {
  // Composite kinds are decoded by separate readers.
  if (IsArrayAttribute(attribute_type))
    return ReadArrayAttribute(attribute_type, offset);
  if (IsDenseAttribute(attribute_type)) return ReadDenseAttribute(offset);
  if (attribute_type == BEFAttributeType::kAggregate)
    return ReadAggregateAttribute(offset);

  // Start of this attribute's bytes.
  const auto payload = &attributes_[offset];

  // True when `attribute_type` denotes the given data type.
  const auto is_dtype = [attribute_type](auto dtype) {
    return attribute_type == static_cast<BEFAttributeType>(dtype);
  };

  // Builds an unsigned integer attribute of the given bit width.
  const auto unsigned_int = [&](unsigned width, int64_t value) {
    return mlir::IntegerAttr::get(
        mlir::IntegerType::get(&context_, width, mlir::IntegerType::Unsigned),
        value);
  };

  // Builds a signless integer attribute of the given bit width.
  const auto signless_int = [&](unsigned width, int64_t value) {
    return mlir::IntegerAttr::get(
        mlir::IntegerType::get(&context_, width, mlir::IntegerType::Signless),
        value);
  };

  // Unsigned integers.
  if (is_dtype(DType::UI8))
    return unsigned_int(8, Attribute<uint8_t>(payload).get());
  if (is_dtype(DType::UI16))
    return unsigned_int(16, Attribute<uint16_t>(payload).get());
  if (is_dtype(DType::UI32))
    return unsigned_int(32, Attribute<uint32_t>(payload).get());
  if (is_dtype(DType::UI64))
    return unsigned_int(64, Attribute<uint64_t>(payload).get());

  // Signless integers. I1 is read from a single byte.
  if (is_dtype(DType::I1))
    return signless_int(1, Attribute<uint8_t>(payload).get());
  if (is_dtype(DType::I8))
    return signless_int(8, Attribute<int8_t>(payload).get());
  if (is_dtype(DType::I16))
    return signless_int(16, Attribute<int16_t>(payload).get());
  if (is_dtype(DType::I32))
    return signless_int(32, Attribute<int32_t>(payload).get());
  if (is_dtype(DType::I64))
    return signless_int(64, Attribute<int64_t>(payload).get());

  // Floats with a native C++ representation.
  if (is_dtype(DType::F32))
    return mlir::FloatAttr::get(builder_.getF32Type(),
                                Attribute<float>(payload).get());
  if (is_dtype(DType::F64))
    return mlir::FloatAttr::get(builder_.getF64Type(),
                                Attribute<double>(payload).get());

  // 16-bit floats are read as a uint16_t bit pattern, which is then
  // interpreted using the float semantics of the target type.
  if (is_dtype(DType::F16) || is_dtype(DType::BF16)) {
    auto float_type = is_dtype(DType::BF16) ? builder_.getBF16Type()
                                            : builder_.getF16Type();
    auto bits_attr = unsigned_int(16, Attribute<uint16_t>(payload).get());
    return mlir::FloatAttr::get(
        float_type,
        llvm::APFloat(float_type.getFloatSemantics(), bits_attr.getValue()));
  }

  if (is_dtype(DType::String))
    return mlir::StringAttr::get(&context_, StringAttribute(payload).get());

  // A type attribute is read as one byte and decoded as a DType.
  if (attribute_type == BEFAttributeType::kType) {
    const auto encoded_dtype =
        static_cast<DType>(Attribute<uint8_t>(payload).get());
    return mlir::TypeAttr::get(DecodeTypeAttribute(&builder_, encoded_dtype));
  }

  // Shapes without a rank become a ShapeAttr built with no dimension list.
  if (attribute_type == BEFAttributeType::kShape) {
    auto bef_shape = ShapeAttr(payload);
    if (!bef_shape.HasRank())
      return tfrt::corert::ShapeAttr::get(builder_.getContext());
    return tfrt::corert::ShapeAttr::get(builder_.getContext(),
                                        bef_shape.GetShape());
  }

  llvm_unreachable("Unknown attribute");
}