in Sources/x10/xla_tensor/ops/scalar.cpp [51:107]
// Lowers this scalar constant node to an XLA op.
//
// Builds a rank-0 literal of shape().element_type(), stores value_ converted to
// the matching C++ element type, and emits it as a constant on
// loctx->builder(). If the node's shape has rank > 0, the scalar constant is
// broadcast to shape().dimensions(). Element types with no case below are
// reported through XLA_ERROR() (presumably throwing; see that macro).
XlaOpVector Scalar::Lower(LoweringContext* loctx) const {
  xla::Literal literal(xla::ShapeUtil::MakeShape(shape().element_type(), {}));
  switch (shape().element_type()) {
    case xla::PrimitiveType::PRED:
      literal.Set<bool>({}, static_cast<bool>(value_.toInt()));
      break;
    case xla::PrimitiveType::S8:
      literal.Set<xla::int8>({}, static_cast<xla::int8>(value_.toChar()));
      break;
    case xla::PrimitiveType::U8:
      literal.Set<xla::uint8>({}, static_cast<xla::uint8>(value_.toByte()));
      break;
    case xla::PrimitiveType::S16:
      literal.Set<xla::int16>({}, static_cast<xla::int16>(value_.toShort()));
      break;
    case xla::PrimitiveType::U16:
      // Read through a wider signed accessor: toShort() only covers
      // [-32768, 32767], so uint16 values above that would be narrowed (or,
      // with a range-checked accessor, rejected) before the unsigned cast.
      literal.Set<xla::uint16>({}, static_cast<xla::uint16>(value_.toInt()));
      break;
    case xla::PrimitiveType::S32:
      literal.Set<xla::int32>({}, static_cast<xla::int32>(value_.toInt()));
      break;
    case xla::PrimitiveType::U32:
      // Same reasoning as U16: toInt() cannot hold uint32 values above
      // 2^31 - 1, so widen to toLong() before the unsigned cast.
      literal.Set<xla::uint32>({}, static_cast<xla::uint32>(value_.toLong()));
      break;
    case xla::PrimitiveType::S64:
      literal.Set<xla::int64>({}, static_cast<xla::int64>(value_.toLong()));
      break;
    case xla::PrimitiveType::U64:
      // toLong() is the widest integer accessor used here; uint64 values
      // beyond the int64 range are not representable through it.
      literal.Set<xla::uint64>({}, static_cast<xla::uint64>(value_.toLong()));
      break;
    case xla::PrimitiveType::F32:
      literal.Set<float>({}, static_cast<float>(value_.toDouble()));
      break;
    case xla::PrimitiveType::F64:
      literal.Set<double>({}, value_.toDouble());
      break;
    case xla::PrimitiveType::BF16:
      literal.Set<xla::bfloat16>({},
                                 static_cast<xla::bfloat16>(value_.toDouble()));
      break;
    case xla::PrimitiveType::F16:
      literal.Set<xla::half>({}, static_cast<xla::half>(value_.toDouble()));
      break;
    default: {
      // Stringify the scalar first so the error message shows its value.
      std::stringstream ss;
      ss << value_;
      XLA_ERROR() << "Unable to lower scalar " << ss.str() << " of shape "
                  << shape();
    }
  }
  xla::XlaOp op = xla::ConstantLiteral(loctx->builder(), literal);
  // The literal is always rank 0; expand it when the node's shape is not.
  if (shape().rank() > 0) {
    op = xla::Broadcast(op, shape().dimensions());
  }
  return ReturnOp(op, loctx);
}