cpp/core/jni/JniCommon.h (428 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include <arrow/ipc/reader.h> #include <arrow/ipc/writer.h> #include <execinfo.h> #include <jni.h> #include "compute/ProtobufUtils.h" #include "compute/Runtime.h" #include "memory/AllocationListener.h" #include "shuffle/rss/RssClient.h" #include "utils/Compression.h" #include "utils/Exception.h" #include "utils/ResourceMap.h" static jint jniVersion = JNI_VERSION_1_8; static inline std::string jStringToCString(JNIEnv* env, jstring string) { int32_t jlen, clen; clen = env->GetStringUTFLength(string); jlen = env->GetStringLength(string); char buffer[clen + 1]; env->GetStringUTFRegion(string, 0, jlen, buffer); return std::string(buffer, clen); } static inline void checkException(JNIEnv* env) { if (env->ExceptionCheck()) { jthrowable t = env->ExceptionOccurred(); env->ExceptionClear(); jclass describerClass = env->FindClass("org/apache/gluten/exception/JniExceptionDescriber"); jmethodID describeMethod = env->GetStaticMethodID(describerClass, "describe", "(Ljava/lang/Throwable;)Ljava/lang/String;"); std::string description = jStringToCString(env, (jstring)env->CallStaticObjectMethod(describerClass, describeMethod, t)); if (env->ExceptionCheck()) { LOG(WARNING) << "Fatal: Uncaught Java exception during calling the Java exception describer method! "; } throw gluten::GlutenException("Error during calling Java code from native code: " + description); } } static inline jclass createGlobalClassReference(JNIEnv* env, const char* className) { jclass localClass = env->FindClass(className); jclass globalClass = (jclass)env->NewGlobalRef(localClass); env->DeleteLocalRef(localClass); return globalClass; } static inline jclass createGlobalClassReferenceOrError(JNIEnv* env, const char* className) { jclass globalClass = createGlobalClassReference(env, className); if (globalClass == nullptr) { std::string errorMessage = "Unable to create global class reference for" + std::string(className); throw gluten::GlutenException(errorMessage); } return globalClass; } static inline jmethodID getMethodId(JNIEnv* env, jclass thisClass, const char* name, const char* sig) { jmethodID ret = env->GetMethodID(thisClass, name, sig); return ret; } static inline jmethodID getMethodIdOrError(JNIEnv* env, jclass thisClass, const char* name, const char* sig) { jmethodID ret = getMethodId(env, thisClass, name, sig); if (ret == nullptr) { std::string errorMessage = "Unable to find method " + std::string(name) + " within signature" + std::string(sig); throw gluten::GlutenException(errorMessage); } return ret; } static inline jmethodID getStaticMethodId(JNIEnv* env, jclass thisClass, const char* name, const char* sig) { jmethodID ret = env->GetStaticMethodID(thisClass, name, sig); return ret; } static inline jmethodID getStaticMethodIdOrError(JNIEnv* env, jclass thisClass, const char* name, const char* sig) { jmethodID ret = getStaticMethodId(env, thisClass, name, sig); if (ret == nullptr) { std::string errorMessage = "Unable to find static method " + std::string(name) + " within signature" + std::string(sig); throw gluten::GlutenException(errorMessage); } return ret; } static inline void attachCurrentThreadAsDaemonOrThrow(JavaVM* vm, JNIEnv** out) { int getEnvStat = vm->GetEnv(reinterpret_cast<void**>(out), jniVersion); if (getEnvStat == JNI_EDETACHED) { DLOG(INFO) << "JNIEnv was not attached to current thread."; // Reattach current thread to JVM getEnvStat = vm->AttachCurrentThreadAsDaemon(reinterpret_cast<void**>(out), NULL); if (getEnvStat != JNI_OK) { throw gluten::GlutenException("Failed to reattach current thread to JVM."); } DLOG(INFO) << "Succeeded attaching current thread."; return; } if (getEnvStat != JNI_OK) { throw gluten::GlutenException("Failed to attach current thread to JVM."); } } template <typename T> static T* jniCastOrThrow(jlong handle) { auto instance = reinterpret_cast<T*>(handle); GLUTEN_CHECK(instance != nullptr, "FATAL: resource instance should not be null."); return instance; } namespace gluten { class JniCommonState { public: virtual ~JniCommonState() = default; void ensureInitialized(JNIEnv* env); void assertInitialized(); void close(); jmethodID runtimeAwareCtxHandle(); private: void initialize(JNIEnv* env); jclass runtimeAwareClass_; jmethodID runtimeAwareCtxHandle_; JavaVM* vm_; bool initialized_{false}; bool closed_{false}; std::mutex mtx_; }; inline JniCommonState* getJniCommonState() { static JniCommonState jniCommonState; return &jniCommonState; } Runtime* getRuntime(JNIEnv* env, jobject runtimeAware); // Safe version of JNI {Get|Release}<PrimitiveType>ArrayElements routines. // SafeNativeArray would release the managed array elements automatically // during destruction. enum class JniPrimitiveArrayType { kBoolean = 0, kByte = 1, kChar = 2, kShort = 3, kInt = 4, kLong = 5, kFloat = 6, kDouble = 7 }; #define CONCATENATE(t1, t2, t3) t1##t2##t3 #define DEFINE_PRIMITIVE_ARRAY(PRIM_TYPE, JAVA_TYPE, JNI_NATIVE_TYPE, NATIVE_TYPE, METHOD_VAR) \ template <> \ struct JniPrimitiveArray<JniPrimitiveArrayType::PRIM_TYPE> { \ using JavaType = JAVA_TYPE; \ using JniNativeType = JNI_NATIVE_TYPE; \ using NativeType = NATIVE_TYPE; \ \ static JniNativeType get(JNIEnv* env, JavaType javaArray) { \ return env->CONCATENATE(Get, METHOD_VAR, ArrayElements)(javaArray, nullptr); \ } \ \ static void release(JNIEnv* env, JavaType javaArray, JniNativeType nativeArray) { \ env->CONCATENATE(Release, METHOD_VAR, ArrayElements)(javaArray, nativeArray, JNI_ABORT); \ } \ }; template <JniPrimitiveArrayType TYPE> struct JniPrimitiveArray {}; DEFINE_PRIMITIVE_ARRAY(kBoolean, jbooleanArray, jboolean*, bool*, Boolean) DEFINE_PRIMITIVE_ARRAY(kByte, jbyteArray, jbyte*, uint8_t*, Byte) DEFINE_PRIMITIVE_ARRAY(kChar, jcharArray, jchar*, uint16_t*, Char) DEFINE_PRIMITIVE_ARRAY(kShort, jshortArray, jshort*, int16_t*, Short) DEFINE_PRIMITIVE_ARRAY(kInt, jintArray, jint*, int32_t*, Int) DEFINE_PRIMITIVE_ARRAY(kLong, jlongArray, jlong*, int64_t*, Long) DEFINE_PRIMITIVE_ARRAY(kFloat, jfloatArray, jfloat*, float_t*, Float) DEFINE_PRIMITIVE_ARRAY(kDouble, jdoubleArray, jdouble*, double_t*, Double) template <JniPrimitiveArrayType TYPE> class SafeNativeArray { using PrimitiveArray = JniPrimitiveArray<TYPE>; using JavaArrayType = typename PrimitiveArray::JavaType; using JniNativeArrayType = typename PrimitiveArray::JniNativeType; using NativeArrayType = typename PrimitiveArray::NativeType; public: virtual ~SafeNativeArray() { PrimitiveArray::release(env_, javaArray_, nativeArray_); } SafeNativeArray(const SafeNativeArray&) = delete; SafeNativeArray(SafeNativeArray&&) = delete; SafeNativeArray& operator=(const SafeNativeArray&) = delete; SafeNativeArray& operator=(SafeNativeArray&&) = delete; const NativeArrayType elems() const { return reinterpret_cast<const NativeArrayType>(nativeArray_); } const jsize length() const { return env_->GetArrayLength(javaArray_); } static SafeNativeArray<TYPE> get(JNIEnv* env, JavaArrayType javaArray) { JniNativeArrayType nativeArray = PrimitiveArray::get(env, javaArray); return SafeNativeArray<TYPE>(env, javaArray, nativeArray); } private: SafeNativeArray(JNIEnv* env, JavaArrayType javaArray, JniNativeArrayType nativeArray) : env_(env), javaArray_(javaArray), nativeArray_(nativeArray){}; JNIEnv* env_; JavaArrayType javaArray_; JniNativeArrayType nativeArray_; }; #define DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(PRIM_TYPE, JAVA_TYPE, METHOD_VAR) \ inline SafeNativeArray<JniPrimitiveArrayType::PRIM_TYPE> CONCATENATE(get, METHOD_VAR, ArrayElementsSafe)( \ JNIEnv * env, JAVA_TYPE array) { \ return SafeNativeArray<JniPrimitiveArrayType::PRIM_TYPE>::get(env, array); \ } DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kBoolean, jbooleanArray, Boolean) DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kByte, jbyteArray, Byte) DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kChar, jcharArray, Char) DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kShort, jshortArray, Short) DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kInt, jintArray, Int) DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kLong, jlongArray, Long) DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kFloat, jfloatArray, Float) DEFINE_SAFE_GET_PRIMITIVE_ARRAY_FUNCTIONS(kDouble, jdoubleArray, Double) class JniColumnarBatchIterator : public ColumnarBatchIterator { public: explicit JniColumnarBatchIterator( JNIEnv* env, jobject jColumnarBatchItr, Runtime* runtime, std::optional<int32_t> iteratorIndex = std::nullopt); // singleton JniColumnarBatchIterator(const JniColumnarBatchIterator&) = delete; JniColumnarBatchIterator(JniColumnarBatchIterator&&) = delete; JniColumnarBatchIterator& operator=(const JniColumnarBatchIterator&) = delete; JniColumnarBatchIterator& operator=(JniColumnarBatchIterator&&) = delete; ~JniColumnarBatchIterator() override; std::shared_ptr<ColumnarBatch> next() override; private: class ColumnarBatchIteratorDumper final : public ColumnarBatchIterator { public: ColumnarBatchIteratorDumper(JniColumnarBatchIterator* self) : self_(self){}; std::shared_ptr<ColumnarBatch> next() override { return self_->nextInternal(); } private: JniColumnarBatchIterator* self_; }; std::shared_ptr<ColumnarBatch> nextInternal() const; JavaVM* vm_; jobject jColumnarBatchItr_; Runtime* runtime_; std::optional<int32_t> iteratorIndex_; const bool shouldDump_; jclass serializedColumnarBatchIteratorClass_; jmethodID serializedColumnarBatchIteratorHasNext_; jmethodID serializedColumnarBatchIteratorNext_; std::shared_ptr<ColumnarBatchIterator> dumpedIteratorReader_{nullptr}; }; std::unique_ptr<JniColumnarBatchIterator> makeJniColumnarBatchIterator(JNIEnv* env, jobject jColumnarBatchItr, Runtime* runtime); } // namespace gluten // TODO: Move the static functions to namespace gluten static inline void backtrace() { void* array[1024]; auto size = backtrace(array, 1024); char** strings = backtrace_symbols(array, size); for (size_t i = 0; i < size; ++i) { LOG(INFO) << strings[i]; } free(strings); } static inline arrow::Compression::type getCompressionType(JNIEnv* env, jstring codecJstr) { if (codecJstr == NULL) { return arrow::Compression::UNCOMPRESSED; } auto codec = env->GetStringUTFChars(codecJstr, JNI_FALSE); // Convert codec string into lowercase. std::string codecLower; std::transform(codec, codec + std::strlen(codec), std::back_inserter(codecLower), ::tolower); GLUTEN_ASSIGN_OR_THROW(auto compressionType, arrow::util::Codec::GetCompressionType(codecLower)); env->ReleaseStringUTFChars(codecJstr, codec); return compressionType; } static inline gluten::CodecBackend getCodecBackend(JNIEnv* env, jstring codecJstr) { if (codecJstr == nullptr) { return gluten::CodecBackend::NONE; } auto codecBackend = jStringToCString(env, codecJstr); if (codecBackend == "qat") { return gluten::CodecBackend::QAT; } else if (codecBackend == "iaa") { return gluten::CodecBackend::IAA; } else { throw std::invalid_argument("Not support this codec backend " + codecBackend); } } static inline gluten::CompressionMode getCompressionMode(JNIEnv* env, jstring compressionModeJstr) { GLUTEN_DCHECK(compressionModeJstr != nullptr, "CompressionMode cannot be null"); auto compressionMode = jStringToCString(env, compressionModeJstr); if (compressionMode == "buffer") { return gluten::CompressionMode::BUFFER; } else if (compressionMode == "rowvector") { return gluten::CompressionMode::ROWVECTOR; } else { throw std::invalid_argument("Not support this compression mode " + compressionMode); } } /* NOTE: the class must be thread safe */ class SparkAllocationListener final : public gluten::AllocationListener { public: SparkAllocationListener(JavaVM* vm, jobject jListenerLocalRef) : vm_(vm) { JNIEnv* env; attachCurrentThreadAsDaemonOrThrow(vm_, &env); jListenerGlobalRef_ = env->NewGlobalRef(jListenerLocalRef); } SparkAllocationListener(const SparkAllocationListener&) = delete; SparkAllocationListener(SparkAllocationListener&&) = delete; SparkAllocationListener& operator=(const SparkAllocationListener&) = delete; SparkAllocationListener& operator=(SparkAllocationListener&&) = delete; ~SparkAllocationListener() override { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) { LOG(WARNING) << "SparkAllocationListener#~SparkAllocationListener(): " << "JNIEnv was not attached to current thread"; return; } env->DeleteGlobalRef(jListenerGlobalRef_); } void allocationChanged(int64_t size) override { if (size == 0) { return; } JNIEnv* env; attachCurrentThreadAsDaemonOrThrow(vm_, &env); if (size < 0) { env->CallLongMethod(jListenerGlobalRef_, unreserveMemoryMethod(env), -size); checkException(env); } else { env->CallLongMethod(jListenerGlobalRef_, reserveMemoryMethod(env), size); checkException(env); } usedBytes_ += size; while (true) { int64_t savedPeakBytes = peakBytes_; int64_t savedUsedBytes = usedBytes_; if (savedUsedBytes <= savedPeakBytes) { break; } // usedBytes_ > savedPeakBytes, update peak if (peakBytes_.compare_exchange_weak(savedPeakBytes, savedUsedBytes)) { break; } } } int64_t currentBytes() override { return usedBytes_; } int64_t peakBytes() override { return peakBytes_; } private: jclass javaReservationListenerClass(JNIEnv* env) { static jclass javaReservationListenerClass = createGlobalClassReference( env, "Lorg/apache/gluten/memory/listener/" "ReservationListener;"); return javaReservationListenerClass; } jmethodID reserveMemoryMethod(JNIEnv* env) { static jmethodID reserveMemoryMethod = getMethodIdOrError(env, javaReservationListenerClass(env), "reserve", "(J)J"); return reserveMemoryMethod; } jmethodID unreserveMemoryMethod(JNIEnv* env) { static jmethodID unreserveMemoryMethod = getMethodIdOrError(env, javaReservationListenerClass(env), "unreserve", "(J)J"); return unreserveMemoryMethod; } JavaVM* vm_; jobject jListenerGlobalRef_; std::atomic_int64_t usedBytes_{0L}; std::atomic_int64_t peakBytes_{0L}; }; class BacktraceAllocationListener final : public gluten::AllocationListener { public: BacktraceAllocationListener(std::unique_ptr<gluten::AllocationListener> delegator) : delegator_(std::move(delegator)) {} void allocationChanged(int64_t bytes) override { allocationBacktrace(bytes); delegator_->allocationChanged(bytes); } private: void allocationBacktrace(int64_t bytes) { allocatedBytes_ += bytes; if (bytes > (64L << 20)) { backtrace(); } else if (allocatedBytes_ >= backtraceBytes_) { backtrace(); backtraceBytes_ += (1L << 30); } } std::unique_ptr<gluten::AllocationListener> delegator_; std::atomic_int64_t allocatedBytes_{}; std::atomic_int64_t backtraceBytes_{1L << 30}; }; class JavaRssClient : public RssClient { public: JavaRssClient(JavaVM* vm, jobject javaRssShuffleWriter, jmethodID javaPushPartitionDataMethod) : vm_(vm), javaPushPartitionData_(javaPushPartitionDataMethod) { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } javaRssShuffleWriter_ = env->NewGlobalRef(javaRssShuffleWriter); array_ = env->NewByteArray(1024 * 1024); array_ = static_cast<jbyteArray>(env->NewGlobalRef(array_)); } ~JavaRssClient() { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) { LOG(WARNING) << "JavaRssClient#~JavaRssClient(): " << "JNIEnv was not attached to current thread"; return; } env->DeleteGlobalRef(javaRssShuffleWriter_); jbyte* byteArray = env->GetByteArrayElements(array_, NULL); env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT); env->DeleteGlobalRef(array_); } int32_t pushPartitionData(int32_t partitionId, const char* bytes, int64_t size) override { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast<void**>(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } jint length = env->GetArrayLength(array_); if (size > length) { jbyte* byteArray = env->GetByteArrayElements(array_, NULL); env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT); env->DeleteGlobalRef(array_); array_ = env->NewByteArray(size); if (array_ == nullptr) { LOG(WARNING) << "Failed to allocate new byte array size: " << size; throw gluten::GlutenException("Failed to allocate new byte array"); } array_ = static_cast<jbyteArray>(env->NewGlobalRef(array_)); } env->SetByteArrayRegion(array_, 0, size, (jbyte*)bytes); jint javaBytesSize = env->CallIntMethod(javaRssShuffleWriter_, javaPushPartitionData_, partitionId, array_, size); checkException(env); return static_cast<int32_t>(javaBytesSize); } void stop() override {} private: JavaVM* vm_; jobject javaRssShuffleWriter_; jmethodID javaPushPartitionData_; jbyteArray array_; };