XlaOpVector Scalar::Lower()

in Sources/x10/xla_tensor/ops/scalar.cpp [51:107]


// Lowers this scalar node to an XLA op.
//
// A rank-0 literal of the node's element type is created and filled with
// value_ converted to the matching C++ type. The literal is emitted as a
// constant on loctx's builder and, when the node's shape has rank > 0,
// broadcast to the shape's dimensions. Element types without a case below
// are reported through XLA_ERROR().
XlaOpVector Scalar::Lower(LoweringContext* loctx) const {
  xla::Literal literal(xla::ShapeUtil::MakeShape(shape().element_type(), {}));
  switch (shape().element_type()) {
    case xla::PrimitiveType::PRED:
      literal.Set<bool>({}, static_cast<bool>(value_.toInt()));
      break;
    case xla::PrimitiveType::S8:
      literal.Set<xla::int8>({}, static_cast<xla::int8>(value_.toChar()));
      break;
    case xla::PrimitiveType::U8:
      literal.Set<xla::uint8>({}, static_cast<xla::uint8>(value_.toByte()));
      break;
    case xla::PrimitiveType::S16:
      literal.Set<xla::int16>({}, static_cast<xla::int16>(value_.toShort()));
      break;
    case xla::PrimitiveType::U16:
      // Read through the 64-bit accessor: an int16 cannot represent
      // 32768..65535, so going through toShort() would lose (or, if the
      // accessor is range-checked, reject) the upper half of the uint16
      // range. The cast narrows modulo 2^16, matching the previous result
      // for every value that fits in int16.
      // NOTE(review): assumes toShort() is range-checked as in c10::Scalar —
      // confirm against the Scalar implementation used here.
      literal.Set<xla::uint16>({}, static_cast<xla::uint16>(value_.toLong()));
      break;
    case xla::PrimitiveType::S32:
      literal.Set<xla::int32>({}, static_cast<xla::int32>(value_.toInt()));
      break;
    case xla::PrimitiveType::U32:
      // Same reasoning as U16: values in 2^31..2^32-1 do not fit in the
      // int32 returned by toInt(), so read the wide value and narrow here.
      literal.Set<xla::uint32>({}, static_cast<xla::uint32>(value_.toLong()));
      break;
    case xla::PrimitiveType::S64:
      literal.Set<xla::int64>({}, static_cast<xla::int64>(value_.toLong()));
      break;
    case xla::PrimitiveType::U64:
      literal.Set<xla::uint64>({}, static_cast<xla::uint64>(value_.toLong()));
      break;
    case xla::PrimitiveType::F32:
      literal.Set<float>({}, static_cast<float>(value_.toDouble()));
      break;
    case xla::PrimitiveType::F64:
      literal.Set<double>({}, value_.toDouble());
      break;
    case xla::PrimitiveType::BF16:
      literal.Set<xla::bfloat16>({},
                                 static_cast<xla::bfloat16>(value_.toDouble()));
      break;
    case xla::PrimitiveType::F16:
      literal.Set<xla::half>({}, static_cast<xla::half>(value_.toDouble()));
      break;
    default: {
      // Render the scalar into a string first so it can be embedded in the
      // error message.
      std::stringstream ss;
      ss << value_;
      XLA_ERROR() << "Unable to lower scalar " << ss.str() << " of shape "
                  << shape();
    }
  }

  xla::XlaOp op = xla::ConstantLiteral(loctx->builder(), literal);
  // The literal is always rank-0; expand it when the node's shape is not.
  if (shape().rank() > 0) {
    op = xla::Broadcast(op, shape().dimensions());
  }
  return ReturnOp(op, loctx);
}