/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include "compute/ProtobufUtils.h" #include "compute/Runtime.h" #include "config/GlutenConfig.h" #include "memory/AllocationListener.h" #include "shuffle/rss/RssClient.h" #include "utils/Compression.h" #include "utils/exception.h" static jint jniVersion = JNI_VERSION_1_8; static inline std::string jStringToCString(JNIEnv* env, jstring string) { int32_t jlen, clen; clen = env->GetStringUTFLength(string); jlen = env->GetStringLength(string); char buffer[clen]; env->GetStringUTFRegion(string, 0, jlen, buffer); return std::string(buffer, clen); } static inline void checkException(JNIEnv* env) { if (env->ExceptionCheck()) { jthrowable t = env->ExceptionOccurred(); env->ExceptionClear(); jclass describerClass = env->FindClass("io/glutenproject/exception/JniExceptionDescriber"); jmethodID describeMethod = env->GetStaticMethodID(describerClass, "describe", "(Ljava/lang/Throwable;)Ljava/lang/String;"); std::string description = jStringToCString(env, (jstring)env->CallStaticObjectMethod(describerClass, describeMethod, t)); if (env->ExceptionCheck()) { LOG(WARNING) << "Fatal: Uncaught Java exception during calling the Java exception describer method! "; } throw gluten::GlutenException("Error during calling Java code from native code: " + description); } } static inline jclass createGlobalClassReference(JNIEnv* env, const char* className) { jclass localClass = env->FindClass(className); jclass globalClass = (jclass)env->NewGlobalRef(localClass); env->DeleteLocalRef(localClass); return globalClass; } static inline jclass createGlobalClassReferenceOrError(JNIEnv* env, const char* className) { jclass globalClass = createGlobalClassReference(env, className); if (globalClass == nullptr) { std::string errorMessage = "Unable to CreateGlobalClassReferenceOrError for" + std::string(className); throw gluten::GlutenException(errorMessage); } return globalClass; } static inline jmethodID getMethodId(JNIEnv* env, jclass thisClass, const char* name, const char* sig) { jmethodID ret = env->GetMethodID(thisClass, name, sig); return ret; } static inline jmethodID getMethodIdOrError(JNIEnv* env, jclass thisClass, const char* name, const char* sig) { jmethodID ret = getMethodId(env, thisClass, name, sig); if (ret == nullptr) { std::string errorMessage = "Unable to find method " + std::string(name) + " within signature" + std::string(sig); throw gluten::GlutenException(errorMessage); } return ret; } static inline jmethodID getStaticMethodId(JNIEnv* env, jclass thisClass, const char* name, const char* sig) { jmethodID ret = env->GetStaticMethodID(thisClass, name, sig); return ret; } static inline jmethodID getStaticMethodIdOrError(JNIEnv* env, jclass thisClass, const char* name, const char* sig) { jmethodID ret = getStaticMethodId(env, thisClass, name, sig); if (ret == nullptr) { std::string errorMessage = "Unable to find static method " + std::string(name) + " within signature" + std::string(sig); throw gluten::GlutenException(errorMessage); } return ret; } static inline void attachCurrentThreadAsDaemonOrThrow(JavaVM* vm, JNIEnv** out) { int getEnvStat = vm->GetEnv(reinterpret_cast(out), jniVersion); if (getEnvStat == JNI_EDETACHED) { DLOG(INFO) << "JNIEnv was not attached to current thread."; // Reattach current thread to JVM getEnvStat = vm->AttachCurrentThreadAsDaemon(reinterpret_cast(out), NULL); if (getEnvStat != JNI_OK) { throw gluten::GlutenException("Failed to reattach current thread to JVM."); } DLOG(INFO) << "Succeeded attaching current thread."; return; } if (getEnvStat != JNI_OK) { throw gluten::GlutenException("Failed to attach current thread to JVM."); } } namespace gluten { class JniCommonState { public: virtual ~JniCommonState() = default; void ensureInitialized(JNIEnv* env); void assertInitialized(); void close(); jmethodID runtimeAwareCtxHandle(); private: void initialize(JNIEnv* env); jclass runtimeAwareClass_; jmethodID runtimeAwareCtxHandle_; JavaVM* vm_; bool initialized_{false}; bool closed_{false}; std::mutex mtx_; }; inline JniCommonState* getJniCommonState() { static JniCommonState jniCommonState; return &jniCommonState; } Runtime* getRuntime(JNIEnv* env, jobject runtimeAware); } // namespace gluten // TODO: Move the static functions to namespace gluten static inline void backtrace() { void* array[1024]; auto size = backtrace(array, 1024); char** strings = backtrace_symbols(array, size); for (size_t i = 0; i < size; ++i) { LOG(INFO) << strings[i]; } free(strings); } static inline arrow::Compression::type getCompressionType(JNIEnv* env, jstring codecJstr) { if (codecJstr == NULL) { return arrow::Compression::UNCOMPRESSED; } auto codec = env->GetStringUTFChars(codecJstr, JNI_FALSE); // Convert codec string into lowercase. std::string codecLower; std::transform(codec, codec + std::strlen(codec), std::back_inserter(codecLower), ::tolower); GLUTEN_ASSIGN_OR_THROW(auto compressionType, arrow::util::Codec::GetCompressionType(codecLower)); env->ReleaseStringUTFChars(codecJstr, codec); return compressionType; } static inline gluten::CodecBackend getCodecBackend(JNIEnv* env, jstring codecJstr) { if (codecJstr == nullptr) { return gluten::CodecBackend::NONE; } auto codecBackend = jStringToCString(env, codecJstr); if (codecBackend == "qat") { return gluten::CodecBackend::QAT; } else if (codecBackend == "iaa") { return gluten::CodecBackend::IAA; } else { throw std::invalid_argument("Not support this codec backend " + codecBackend); } } static inline gluten::CompressionMode getCompressionMode(JNIEnv* env, jstring compressionModeJstr) { GLUTEN_DCHECK(compressionModeJstr != nullptr, "CompressionMode cannot be null"); auto compressionMode = jStringToCString(env, compressionModeJstr); if (compressionMode == "buffer") { return gluten::CompressionMode::BUFFER; } else if (compressionMode == "rowvector") { return gluten::CompressionMode::ROWVECTOR; } else { throw std::invalid_argument("Not support this compression mode " + compressionMode); } } class SparkAllocationListener final : public gluten::AllocationListener { public: SparkAllocationListener( JavaVM* vm, jobject jListenerLocalRef, jmethodID jReserveMethod, jmethodID jUnreserveMethod, int64_t blockSize) : vm_(vm), jReserveMethod_(jReserveMethod), jUnreserveMethod_(jUnreserveMethod), blockSize_(blockSize) { JNIEnv* env; attachCurrentThreadAsDaemonOrThrow(vm_, &env); jListenerGlobalRef_ = env->NewGlobalRef(jListenerLocalRef); } SparkAllocationListener(const SparkAllocationListener&) = delete; SparkAllocationListener(SparkAllocationListener&&) = delete; SparkAllocationListener& operator=(const SparkAllocationListener&) = delete; SparkAllocationListener& operator=(SparkAllocationListener&&) = delete; ~SparkAllocationListener() override { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { LOG(WARNING) << "SparkAllocationListener#~SparkAllocationListener(): " << "JNIEnv was not attached to current thread"; return; } env->DeleteGlobalRef(jListenerGlobalRef_); } void allocationChanged(int64_t size) override { updateReservation(size); } private: int64_t reserve(int64_t diff) { std::lock_guard lock(mutex_); bytesReserved_ += diff; int64_t newBlockCount; if (bytesReserved_ == 0) { newBlockCount = 0; } else { // ceil to get the required block number newBlockCount = (bytesReserved_ - 1) / blockSize_ + 1; } int64_t bytesGranted = (newBlockCount - blocksReserved_) * blockSize_; blocksReserved_ = newBlockCount; if (bytesReserved_ > maxBytesReserved_) { maxBytesReserved_ = bytesReserved_; } return bytesGranted; } void updateReservation(int64_t diff) { int64_t granted = reserve(diff); if (granted == 0) { return; } JNIEnv* env; attachCurrentThreadAsDaemonOrThrow(vm_, &env); if (granted < 0) { env->CallLongMethod(jListenerGlobalRef_, jUnreserveMethod_, -granted); checkException(env); return; } env->CallLongMethod(jListenerGlobalRef_, jReserveMethod_, granted); checkException(env); } JavaVM* vm_; jobject jListenerGlobalRef_; jmethodID jReserveMethod_; jmethodID jUnreserveMethod_; int64_t blockSize_; int64_t blocksReserved_ = 0L; int64_t bytesReserved_ = 0L; int64_t maxBytesReserved_ = 0L; std::mutex mutex_; }; class BacktraceAllocationListener final : public gluten::AllocationListener { public: BacktraceAllocationListener(std::unique_ptr delegator) : delegator_(std::move(delegator)) {} void allocationChanged(int64_t bytes) override { allocationBacktrace(bytes); delegator_->allocationChanged(bytes); } private: void allocationBacktrace(int64_t bytes) { allocatedBytes_ += bytes; if (bytes > (64L << 20)) { backtrace(); } else if (allocatedBytes_ >= backtraceBytes_) { backtrace(); backtraceBytes_ += (1L << 30); } } std::unique_ptr delegator_; std::atomic_int64_t allocatedBytes_{}; std::atomic_int64_t backtraceBytes_{1L << 30}; }; class CelebornClient : public RssClient { public: CelebornClient(JavaVM* vm, jobject javaCelebornShuffleWriter, jmethodID javaCelebornPushPartitionDataMethod) : vm_(vm), javaCelebornPushPartitionData_(javaCelebornPushPartitionDataMethod) { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } javaCelebornShuffleWriter_ = env->NewGlobalRef(javaCelebornShuffleWriter); array_ = env->NewByteArray(1024 * 1024); array_ = static_cast(env->NewGlobalRef(array_)); } ~CelebornClient() { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { LOG(WARNING) << "CelebornClient#~CelebornClient(): " << "JNIEnv was not attached to current thread"; return; } env->DeleteGlobalRef(javaCelebornShuffleWriter_); jbyte* byteArray = env->GetByteArrayElements(array_, NULL); env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT); env->DeleteGlobalRef(array_); } int32_t pushPartitionData(int32_t partitionId, char* bytes, int64_t size) override { JNIEnv* env; if (vm_->GetEnv(reinterpret_cast(&env), jniVersion) != JNI_OK) { throw gluten::GlutenException("JNIEnv was not attached to current thread"); } jint length = env->GetArrayLength(array_); if (size > length) { jbyte* byteArray = env->GetByteArrayElements(array_, NULL); env->ReleaseByteArrayElements(array_, byteArray, JNI_ABORT); env->DeleteGlobalRef(array_); array_ = env->NewByteArray(size); array_ = static_cast(env->NewGlobalRef(array_)); } env->SetByteArrayRegion(array_, 0, size, reinterpret_cast(bytes)); jint celebornBytesSize = env->CallIntMethod(javaCelebornShuffleWriter_, javaCelebornPushPartitionData_, partitionId, array_, size); checkException(env); return static_cast(celebornBytesSize); } void stop() override {} private: JavaVM* vm_; jobject javaCelebornShuffleWriter_; jmethodID javaCelebornPushPartitionData_; jbyteArray array_; };