JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrapper_nativeMake()

in cpp/core/jni/JniWrapper.cc [789:926]


// JNI entry point backing ShuffleWriterJniWrapper.nativeMake.
//
// Assembles ShuffleWriterOptions and PartitionWriterOptions from the JNI arguments, builds the
// partition writer selected by `partitionWriterTypeJstr` ("local", "celeborn" or "uniffle"), asks
// the runtime to create a shuffle writer on top of it, and returns the handle produced by
// ctx->saveObject().
//
// Throws GlutenException when:
//   - `partitioningNameJstr` or `partitionWriterTypeJstr` is null;
//   - the writer type is "local" and `dataFileJstr` or `localDirsJstr` is null;
//   - the JavaVM cannot be obtained for an RSS ("celeborn"/"uniffle") writer;
//   - the writer type is not one of the three recognized values.
// The body sits between JNI_METHOD_START and JNI_METHOD_END(kInvalidObjectHandle); presumably those
// macros translate C++ exceptions for the Java side and return kInvalidObjectHandle — confirm
// against the macro definitions.
//
// `firstBatchHandle` is part of the JNI signature but is not read by this function.
JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrapper_nativeMake( // NOLINT
    JNIEnv* env,
    jobject wrapper,
    jstring partitioningNameJstr,
    jint numPartitions,
    jint bufferSize,
    jint mergeBufferSize,
    jdouble mergeThreshold,
    jstring codecJstr,
    jstring codecBackendJstr,
    jint compressionLevel,
    jint compressionBufferSize,
    jint diskWriteBufferSize,
    jint compressionThreshold,
    jstring compressionModeJstr,
    jint initialSortBufferSize,
    jboolean useRadixSort,
    jstring dataFileJstr,
    jint numSubDirs,
    jstring localDirsJstr,
    jdouble reallocThreshold,
    jlong firstBatchHandle,
    jlong taskAttemptId,
    jint startPartitionId,
    jint pushBufferMaxSize,
    jlong sortBufferMaxSize,
    jobject partitionPusher,
    jstring partitionWriterTypeJstr,
    jstring shuffleWriterTypeJstr) {
  JNI_METHOD_START
  auto ctx = getRuntime(env, wrapper);
  if (partitioningNameJstr == nullptr) {
    throw GlutenException(std::string("Short partitioning name can't be null"));
  }
  // Reject a missing writer type up front instead of handing a null jstring to JNI.
  if (partitionWriterTypeJstr == nullptr) {
    throw GlutenException(std::string("Partition writer type can't be null"));
  }

  // Build ShuffleWriterOptions.
  auto shuffleWriterOptions = ShuffleWriterOptions{
      .bufferSize = bufferSize,
      .bufferReallocThreshold = reallocThreshold,
      .partitioning = toPartitioning(jStringToCString(env, partitioningNameJstr)),
      .taskAttemptId = static_cast<int64_t>(taskAttemptId),
      .startPartitionId = startPartitionId,
      .shuffleWriterType = ShuffleWriter::stringToType(jStringToCString(env, shuffleWriterTypeJstr)),
      .initialSortBufferSize = initialSortBufferSize,
      .diskWriteBufferSize = diskWriteBufferSize,
      .useRadixSort = static_cast<bool>(useRadixSort)};

  // Build PartitionWriterOptions. Non-positive push/sort buffer limits fall back to the defaults.
  auto partitionWriterOptions = PartitionWriterOptions{
      .mergeBufferSize = mergeBufferSize,
      .mergeThreshold = mergeThreshold,
      .compressionBufferSize = compressionBufferSize,
      .compressionThreshold = compressionThreshold,
      .compressionType = getCompressionType(env, codecJstr),
      .compressionLevel = compressionLevel,
      .bufferedWrite = true,
      .numSubDirs = numSubDirs,
      .pushBufferMaxSize = pushBufferMaxSize > 0 ? pushBufferMaxSize : kDefaultPushMemoryThreshold,
      .sortBufferMaxSize = sortBufferMaxSize > 0 ? sortBufferMaxSize : kDefaultSortBufferThreshold};
  // Codec backend and compression mode are only resolved when a codec was supplied.
  if (codecJstr != nullptr) {
    partitionWriterOptions.codecBackend = getCodecBackend(env, codecBackendJstr);
    partitionWriterOptions.compressionMode = getCompressionMode(env, compressionModeJstr);
  }

  // Optional override of the shuffle file buffer size from the runtime configuration.
  const auto& conf = ctx->getConfMap();
  if (auto it = conf.find(kShuffleFileBufferSize); it != conf.end()) {
    // Parse as 64-bit: the destination is int64_t, and a 32-bit parse would throw
    // std::out_of_range for configured values above INT_MAX.
    partitionWriterOptions.shuffleFileBufferSize = static_cast<int64_t>(std::stoll(it->second));
  }

  // Shared construction path for the remote-shuffle-service writers; the two services differ only
  // in the JVM signature of the pusher class that declares pushPartitionData(int, byte[], int).
  const auto makeRssPartitionWriter = [&](const char* pusherClassSignature) -> std::unique_ptr<PartitionWriter> {
    // NOTE(review): this global class reference is created on every call and never deleted here.
    // It looks like a leak unless something else owns it — confirm what lifetime JavaRssClient
    // needs for the method ID before releasing it.
    jclass pusherClass = createGlobalClassReferenceOrError(env, pusherClassSignature);
    jmethodID pushPartitionDataMethod = getMethodIdOrError(env, pusherClass, "pushPartitionData", "(I[BI)I");
    JavaVM* vm = nullptr;
    if (env->GetJavaVM(&vm) != JNI_OK) {
      throw GlutenException("Unable to get JavaVM instance");
    }
    auto rssClient = std::make_shared<JavaRssClient>(vm, partitionPusher, pushPartitionDataMethod);
    return std::make_unique<RssPartitionWriter>(
        numPartitions,
        std::move(partitionWriterOptions),
        ctx->memoryManager()->getArrowMemoryPool(),
        std::move(rssClient));
  };

  // Same string-conversion helper already used for the partitioning name above.
  const std::string partitionWriterType = jStringToCString(env, partitionWriterTypeJstr);

  std::unique_ptr<PartitionWriter> partitionWriter;
  if (partitionWriterType == "local") {
    if (dataFileJstr == nullptr) {
      throw GlutenException(std::string("Shuffle DataFile can't be null"));
    }
    if (localDirsJstr == nullptr) {
      throw GlutenException(std::string("Shuffle local directories can't be null"));
    }
    const std::string dataFile = jStringToCString(env, dataFileJstr);
    // localDirsJstr carries several paths in one string; splitPaths breaks it into the list.
    auto configuredDirs = splitPaths(jStringToCString(env, localDirsJstr));

    partitionWriter = std::make_unique<LocalPartitionWriter>(
        numPartitions,
        std::move(partitionWriterOptions),
        ctx->memoryManager()->getArrowMemoryPool(),
        dataFile,
        configuredDirs);
  } else if (partitionWriterType == "celeborn") {
    partitionWriter = makeRssPartitionWriter("Lorg/apache/spark/shuffle/CelebornPartitionPusher;");
  } else if (partitionWriterType == "uniffle") {
    partitionWriter = makeRssPartitionWriter("Lorg/apache/spark/shuffle/writer/PartitionPusher;");
  } else {
    throw GlutenException("Unrecognizable partition writer type: " + partitionWriterType);
  }

  return ctx->saveObject(
      ctx->createShuffleWriter(numPartitions, std::move(partitionWriter), std::move(shuffleWriterOptions)));
  JNI_METHOD_END(kInvalidObjectHandle)
}