in lib/core_runtime/op_attrs.cc [745:828]
// Prints the single attribute element stored at `ptr`, interpreted according
// to `type`, to `os`.
//
// - DTYPE prints the name of the stored OpAttrType.
// - AGGREGATE prints "elt_count=N [...]", recursing into dense and nested
//   aggregate elements and printing "unknown" for every other element kind.
// - DENSE prints a summary (dtype, rank, elt_count), not the tensor contents.
// - SHAPE prints "<d0xd1x...>" for ranked shapes and "<*>" for unranked ones.
// - Scalar types listed in op_attr_type.def are streamed by value.
//
// BF16 and F16 cannot be printed yet: they assert in debug builds and print
// nothing otherwise. The UNSUPPORTED_* types are treated as unreachable.
static void PrintElement(const void *ptr, OpAttrType type, raw_ostream &os) {
  switch (type) {
    case OpAttrType::DTYPE: {
      auto dtype = *static_cast<const OpAttrType *>(ptr);
      // The stored dtype must name an element type, never DTYPE itself.
      assert(dtype != OpAttrType::DTYPE);
      os << GetNameString(dtype);
      break;
    }
    case OpAttrType::AGGREGATE: {
      AggregateAttr aggregate_attr(ptr);
      const size_t num_elements = aggregate_attr.GetNumElements();
      os << "elt_count=" << num_elements << " [";
      // size_t index matches the type of num_elements (no signed/unsigned
      // comparison).
      for (size_t i = 0; i < num_elements; ++i) {
        auto base = aggregate_attr.GetAttribute(i);
        os << "{";
        if (IsDenseAttribute(base.type()) ||
            base.type() == BEFAttributeType::kAggregate) {
          PrintElement(base.data(),
                       GetOpAttrTypeFromBEFAttributeType(base.type()), os);
        } else {
          // TODO(chky): Support other types.
          os << "unknown";
        }
        os << "}";
        // Separator between elements, but not after the last one.
        if (i + 1 < num_elements) os << ", ";
      }
      os << "]";
      break;
    }
    case OpAttrType::DENSE: {
      DenseAttr dense_attr(ptr);
      os << "dtype="
         << GetNameString(GetOpAttrTypeFromBEFAttributeType(
                static_cast<BEFAttributeType>(dense_attr.dtype())))
         << ", rank=" << dense_attr.shape().size()
         << ", elt_count=" << dense_attr.GetNumElements();
      break;
    }
    case OpAttrType::SHAPE: {
      ShapeAttr shape_attr(ptr);
      os << "<";
      if (shape_attr.HasRank())
        llvm::interleave(shape_attr.GetShape(), os, "x");
      else
        os << "*";
      os << ">";
      break;
    }
    case OpAttrType::FUNC:
      os << GetNameString(OpAttrType::FUNC);
      // NOTE(review): only the first character at `ptr` is printed. This is
      // presumably correct if callers invoke PrintElement once per stored char
      // of the function name. Confirm; if `ptr` holds the whole name, this
      // truncates it.
      os << " function_name: " << *static_cast<const char *>(ptr);
      break;
    case OpAttrType::BF16:
      assert(0 && "cannot print bf16 yet.");
      break;
    case OpAttrType::F16:
      assert(0 && "cannot print fp16 yet.");
      break;
    case OpAttrType::I1:
      // raw_ostream streams (unsigned) char values as characters, which would
      // emit a raw 0x00/0x01 byte. Widen to an integer so the value prints as
      // a readable number.
      os << static_cast<unsigned>(*static_cast<const uint8_t *>(ptr));
      break;
    case OpAttrType::COMPLEX64:
      os << "(" << static_cast<const std::complex<float> *>(ptr)->real() << ","
         << static_cast<const std::complex<float> *>(ptr)->imag() << ")";
      break;
    case OpAttrType::COMPLEX128:
      os << "(" << static_cast<const std::complex<double> *>(ptr)->real()
         << "," << static_cast<const std::complex<double> *>(ptr)->imag()
         << ")";
      break;
    case OpAttrType::UNSUPPORTED_RESOURCE:
    case OpAttrType::UNSUPPORTED_VARIANT:
    case OpAttrType::UNSUPPORTED_QUI8:
    case OpAttrType::UNSUPPORTED_QUI16:
    case OpAttrType::UNSUPPORTED_QI8:
    case OpAttrType::UNSUPPORTED_QI16:
    case OpAttrType::UNSUPPORTED_QI32:
      llvm_unreachable("unsupported attribute type");
// Remaining scalar types: stream the stored value as its C++ type.
#define OP_ATTR_TYPE(ENUM, CPP_TYPE)             \
  case OpAttrType::ENUM:                         \
    os << *static_cast<const CPP_TYPE *>(ptr);   \
    break;
#include "tfrt/core_runtime/op_attr_type.def"
  }
}