in torch_xla/csrc/helpers.h [41:79]
static xla::Literal ScalarLiteral(T scalar_value, xla::PrimitiveType type) {
switch (type) {
case xla::PrimitiveType::F64:
return xla::LiteralUtil::CreateR0<double>(scalar_value);
case xla::PrimitiveType::F32:
return xla::LiteralUtil::CreateR0<float>(scalar_value);
case xla::PrimitiveType::BF16:
return xla::LiteralUtil::CreateR0<tensorflow::bfloat16>(
static_cast<tensorflow::bfloat16>(
static_cast<float>(scalar_value)));
case xla::PrimitiveType::F16:
return xla::LiteralUtil::CreateR0<xla::half>(
static_cast<xla::half>(static_cast<float>(scalar_value)));
case xla::PrimitiveType::S64:
return xla::LiteralUtil::CreateR0<xla::int64_t>(scalar_value);
case xla::PrimitiveType::U64:
return xla::LiteralUtil::CreateR0<xla::uint64>(scalar_value);
case xla::PrimitiveType::S32:
return xla::LiteralUtil::CreateR0<xla::int32>(scalar_value);
case xla::PrimitiveType::U32:
return xla::LiteralUtil::CreateR0<xla::uint32>(scalar_value);
case xla::PrimitiveType::S16:
return xla::LiteralUtil::CreateR0<xla::int16>(scalar_value);
case xla::PrimitiveType::U16:
return xla::LiteralUtil::CreateR0<xla::uint16>(scalar_value);
case xla::PrimitiveType::S8:
return xla::LiteralUtil::CreateR0<xla::int8>(scalar_value);
case xla::PrimitiveType::U8:
return xla::LiteralUtil::CreateR0<xla::uint8>(scalar_value);
case xla::PrimitiveType::PRED:
return xla::LiteralUtil::CreateR0<bool>(scalar_value);
case xla::PrimitiveType::C64:
return xla::LiteralUtil::CreateR0<xla::complex64>(scalar_value);
case xla::PrimitiveType::C128:
return xla::LiteralUtil::CreateR0<xla::complex128>(scalar_value);
default:
return xla::LiteralUtil::CreateR0<T>(scalar_value);
}
}