static void PrintElement()

in lib/core_runtime/op_attrs.cc [745:828]


// Writes a human-readable rendering of a single attribute element to `os`.
//
//   ptr  - address of the element's payload; it is reinterpreted according to
//          `type`.
//   type - selects how the bytes at `ptr` are interpreted and printed.
//   os   - destination stream.
//
// Aggregates are printed recursively; dense and shape attributes get a short
// summary rather than a dump of their contents. BF16/F16 and the UNSUPPORTED_*
// types are not printable and trip an assert / llvm_unreachable.
static void PrintElement(const void *ptr, OpAttrType type, raw_ostream &os) {
  switch (type) {
    case OpAttrType::DTYPE: {
      // The payload is itself an OpAttrType; a dtype naming DTYPE is rejected.
      const auto dtype = *static_cast<const OpAttrType *>(ptr);
      assert(dtype != OpAttrType::DTYPE);
      os << GetNameString(dtype);
      break;
    }
    case OpAttrType::AGGREGATE: {
      AggregateAttr aggregate_attr(ptr);
      const size_t num_elements = aggregate_attr.GetNumElements();
      os << "elt_count=" << num_elements << " [";
      // Use an unsigned index to match `num_elements`, and emit the separator
      // before every element but the first so that no `num_elements - 1`
      // arithmetic on an unsigned value is needed.
      for (size_t i = 0; i < num_elements; ++i) {
        if (i != 0) os << ", ";
        auto element = aggregate_attr.GetAttribute(i);
        const auto element_type = element.type();
        os << "{";
        // Only dense attributes and nested aggregates are rendered; anything
        // else is reported as "unknown".
        if (IsDenseAttribute(element_type) ||
            element_type == BEFAttributeType::kAggregate) {
          PrintElement(element.data(),
                       GetOpAttrTypeFromBEFAttributeType(element_type), os);
        } else {
          // TODO(chky): Support other types.
          os << "unknown";
        }
        os << "}";
      }
      os << "]";
      break;
    }
    case OpAttrType::DENSE: {
      // Summary only: element type, rank and element count.
      DenseAttr dense_attr(ptr);
      const auto element_dtype = GetOpAttrTypeFromBEFAttributeType(
          static_cast<BEFAttributeType>(dense_attr.dtype()));
      os << "dtype=" << GetNameString(element_dtype)
         << ", rank=" << dense_attr.shape().size()
         << ", elt_count=" << dense_attr.GetNumElements();
      break;
    }
    case OpAttrType::SHAPE: {
      // Ranked shapes print as "<d0xd1x...>", unranked shapes as "<*>".
      ShapeAttr shape_attr(ptr);
      os << "<";
      if (shape_attr.HasRank()) {
        llvm::interleave(shape_attr.GetShape(), os, "x");
      } else {
        os << "*";
      }
      os << ">";
      break;
    }
    case OpAttrType::FUNC:
      os << GetNameString(OpAttrType::FUNC);
      // NOTE(review): the pointer is dereferenced, so only the first character
      // of the payload is streamed after the label. The length of the name is
      // not available in this function and it is not visible here whether the
      // payload is NUL-terminated, so the behavior is left unchanged - confirm
      // the intent against the code that stores FUNC attributes.
      os << " function_name: " << *static_cast<const char *>(ptr);
      break;
    case OpAttrType::BF16:
      assert(0 && "cannot print bf16 yet.");
      break;
    case OpAttrType::F16:
      assert(0 && "cannot print fp16 yet.");
      break;
    case OpAttrType::I1:
      // Widen before streaming: a uint8_t selects raw_ostream's
      // `unsigned char` overload, which would write the value as a raw
      // character instead of as a readable number.
      os << static_cast<unsigned>(*static_cast<const uint8_t *>(ptr));
      break;
    case OpAttrType::COMPLEX64: {
      const auto &value = *static_cast<const std::complex<float> *>(ptr);
      os << "(" << value.real() << "," << value.imag() << ")";
      break;
    }
    case OpAttrType::COMPLEX128: {
      const auto &value = *static_cast<const std::complex<double> *>(ptr);
      os << "(" << value.real() << "," << value.imag() << ")";
      break;
    }
    case OpAttrType::UNSUPPORTED_RESOURCE:
    case OpAttrType::UNSUPPORTED_VARIANT:
    case OpAttrType::UNSUPPORTED_QUI8:
    case OpAttrType::UNSUPPORTED_QUI16:
    case OpAttrType::UNSUPPORTED_QI8:
    case OpAttrType::UNSUPPORTED_QI16:
    case OpAttrType::UNSUPPORTED_QI32:
      llvm_unreachable("unsupported attribute type");
// Every type listed in op_attr_type.def is printed by streaming the payload
// reinterpreted as the C++ type the .def entry associates with it.
#define OP_ATTR_TYPE(ENUM, CPP_TYPE)           \
  case OpAttrType::ENUM:                       \
    os << *static_cast<const CPP_TYPE *>(ptr); \
    break;
#include "tfrt/core_runtime/op_attr_type.def"
  }
}