static xla::Literal ScalarLiteral()

in Sources/x10/xla_tensor/helpers.h [54:92]


  static xla::Literal ScalarLiteral(T scalar_value, xla::PrimitiveType type) {
    switch (type) {
      case xla::PrimitiveType::F64:
        return xla::LiteralUtil::CreateR0<double>(scalar_value);
      case xla::PrimitiveType::F32:
        return xla::LiteralUtil::CreateR0<float>(scalar_value);
      case xla::PrimitiveType::BF16:
        return xla::LiteralUtil::CreateR0<tensorflow::bfloat16>(
            static_cast<tensorflow::bfloat16>(
                static_cast<float>(scalar_value)));
      case xla::PrimitiveType::F16:
        return xla::LiteralUtil::CreateR0<xla::half>(
            static_cast<xla::half>(static_cast<float>(scalar_value)));
      case xla::PrimitiveType::S64:
        return xla::LiteralUtil::CreateR0<xla::int64>(scalar_value);
      case xla::PrimitiveType::U64:
        return xla::LiteralUtil::CreateR0<xla::uint64>(scalar_value);
      case xla::PrimitiveType::S32:
        return xla::LiteralUtil::CreateR0<xla::int32>(scalar_value);
      case xla::PrimitiveType::U32:
        return xla::LiteralUtil::CreateR0<xla::uint32>(scalar_value);
      case xla::PrimitiveType::S16:
        return xla::LiteralUtil::CreateR0<xla::int16>(scalar_value);
      case xla::PrimitiveType::U16:
        return xla::LiteralUtil::CreateR0<xla::uint16>(scalar_value);
      case xla::PrimitiveType::S8:
        return xla::LiteralUtil::CreateR0<xla::int8>(scalar_value);
      case xla::PrimitiveType::U8:
        return xla::LiteralUtil::CreateR0<xla::uint8>(scalar_value);
      case xla::PrimitiveType::PRED:
        return xla::LiteralUtil::CreateR0<bool>(scalar_value);
      case xla::PrimitiveType::C64:
        return xla::LiteralUtil::CreateR0<xla::complex64>(scalar_value);
      case xla::PrimitiveType::C128:
        return xla::LiteralUtil::CreateR0<xla::complex128>(scalar_value);
      default:
        return xla::LiteralUtil::CreateR0<T>(scalar_value);
    }
  }