jsi::Value JavaTurboModule::invokeJavaMethod()

in ReactCommon/react/nativemodule/core/platform/android/ReactCommon/JavaTurboModule.cpp [433:860]


/**
 * Calls the Java method named `methodNameStr` (with JNI signature
 * `methodSignature`) on this module's Java instance and converts the outcome
 * into a jsi::Value.
 *
 * How the call is made depends on `valueKind`:
 *  - BooleanKind / NumberKind / StringKind / ObjectKind / ArrayKind: the Java
 *    method is invoked directly from this function and its return value is
 *    converted. A null object return (boxed Boolean/Double, String, map,
 *    array) is converted to jsi::Value::null().
 *  - VoidKind: the call is handed to nativeInvoker_->invokeAsync and
 *    undefined is returned right away.
 *  - PromiseKind: a JS Promise is constructed. Its executor wraps the
 *    resolve/reject functions in a Java PromiseImpl, stores it in the
 *    argument slot at index `argCount`, and hands the call to
 *    nativeInvoker_->invokeAsync. The Promise is returned.
 *
 * "getConstants" is special-cased before argument conversion: it is invoked
 * without arguments and its result is converted with convertFromJMapToValue
 * (undefined when the Java method returns null).
 *
 * The closures passed to invokeAsync hold copies of the JNI GlobalReferences
 * in JNIArgs::globalRefs_ and delete them on every exit path (Java instance
 * already gone, Java exception, success).
 *
 * @param runtime         JS runtime used to create the returned values.
 * @param valueKind       Return kind of the method; selects the call path.
 * @param methodNameStr   Name of the Java method.
 * @param methodSignature JNI signature of the Java method.
 * @param args            JS arguments.
 * @param argCount        Number of entries in `args`.
 * @return The converted return value, undefined for VoidKind, or a Promise
 *         for PromiseKind.
 * @throws Whatever FACEBOOK_JNI_THROW_PENDING_EXCEPTION raises when a JNI
 *         call made from this function leaves an exception pending;
 *         std::runtime_error when `valueKind` is not handled below.
 */
jsi::Value JavaTurboModule::invokeJavaMethod(
    jsi::Runtime &runtime,
    TurboModuleMethodValueKind valueKind,
    const std::string &methodNameStr,
    const std::string &methodSignature,
    const jsi::Value *args,
    size_t argCount) {
  const char *methodName = methodNameStr.c_str();
  const char *moduleName = name_.c_str();

  // Everything except void and promise methods is reported to TMPL as a
  // synchronous call.
  bool isMethodSync = !(valueKind == VoidKind || valueKind == PromiseKind);

  if (isMethodSync) {
    TMPL::syncMethodCallStart(moduleName, methodName);
    TMPL::syncMethodCallArgConversionStart(moduleName, methodName);
  } else {
    TMPL::asyncMethodCallStart(moduleName, methodName);
    TMPL::asyncMethodCallArgConversionStart(moduleName, methodName);
  }

  JNIEnv *env = jni::Environment::current();
  auto instance = instance_.get();

  /**
   * To account for jclasses and other misc LocalReferences we create.
   */
  unsigned int buffer = 6;
  /**
   * For promises, we have to create a resolve fn, a reject fn, and a promise
   * object. For normal returns, we just create the return object.
   */
  unsigned int maxReturnObjects = 3;

  /**
   * When the return type is void, all JNI LocalReferences are converted to
   * GlobalReferences. The LocalReferences are then promptly deleted
   * after the conversion.
   */
  unsigned int actualArgCount = valueKind == VoidKind ? 0 : argCount;
  unsigned int estimatedLocalRefCount =
      actualArgCount + maxReturnObjects + buffer;

  /**
   * This will push a new JNI stack frame for the LocalReferences in this
   * function call. When the stack frame for invokeJavaMethod is popped,
   * all LocalReferences are deleted.
   *
   * In total, there can be at most kJniLocalRefMax (= 512) Jni
   * LocalReferences alive at a time. estimatedLocalRefCount is provided
   * so that PushLocalFrame can throw an out of memory error when the total
   * number of alive LocalReferences is estimatedLocalRefCount smaller than
   * kJniLocalRefMax.
   */
  jni::JniLocalScope scope(env, estimatedLocalRefCount);

  jclass cls = env->GetObjectClass(instance);
  jmethodID methodID =
      env->GetMethodID(cls, methodName, methodSignature.c_str());

  // Rethrows a pending JNI exception after reporting the call as failed to
  // TMPL. Does nothing when no exception is pending.
  auto checkJNIErrorForMethodCall =
      [methodName, moduleName, isMethodSync]() -> void {
    try {
      FACEBOOK_JNI_THROW_PENDING_EXCEPTION();
    } catch (...) {
      if (isMethodSync) {
        TMPL::syncMethodCallFail(moduleName, methodName);
      } else {
        TMPL::asyncMethodCallFail(moduleName, methodName);
      }
      throw;
    }
  };

  // If the method signature doesn't match, show a redbox here instead of
  // crashing later.
  checkJNIErrorForMethodCall();

  // TODO(T43933641): Refactor to remove this special-casing
  if (methodNameStr == "getConstants") {
    TMPL::syncMethodCallArgConversionEnd(moduleName, methodName);
    TMPL::syncMethodCallExecutionStart(moduleName, methodName);

    auto constantsMap = (jobject)env->CallObjectMethod(instance, methodID);
    checkJNIErrorForMethodCall();

    TMPL::syncMethodCallExecutionEnd(moduleName, methodName);
    TMPL::syncMethodCallReturnConversionStart(moduleName, methodName);

    jsi::Value returnValue = constantsMap == nullptr
        ? jsi::Value::undefined()
        : convertFromJMapToValue(env, runtime, constantsMap);

    TMPL::syncMethodCallReturnConversionEnd(moduleName, methodName);
    TMPL::syncMethodCallEnd(moduleName, methodName);
    return returnValue;
  }

  std::vector<std::string> methodArgTypes =
      getMethodArgTypesFromSignature(methodSignature);

  JNIArgs jniArgs = convertJSIArgsToJNIArgs(
      env,
      runtime,
      methodNameStr,
      methodArgTypes,
      args,
      argCount,
      jsInvoker_,
      valueKind,
      retainJSCallback_);

  // isMethodSync already excludes PromiseKind (and VoidKind); the async
  // paths emit their own "arg conversion end" events further down.
  if (isMethodSync) {
    TMPL::syncMethodCallArgConversionEnd(moduleName, methodName);
    TMPL::syncMethodCallExecutionStart(moduleName, methodName);
  }

  auto &jargs = jniArgs.args_;
  auto &globalRefs = jniArgs.globalRefs_;

  switch (valueKind) {
    case BooleanKind: {
      // The Java method may return either a boxed java.lang.Boolean or a
      // primitive boolean; the return type in the signature decides which
      // JNI call is used.
      std::string returnType =
          methodSignature.substr(methodSignature.find_last_of(')') + 1);
      if (returnType == "Ljava/lang/Boolean;") {
        auto returnObject =
            (jobject)env->CallObjectMethodA(instance, methodID, jargs.data());
        checkJNIErrorForMethodCall();

        TMPL::syncMethodCallExecutionEnd(moduleName, methodName);
        TMPL::syncMethodCallReturnConversionStart(moduleName, methodName);

        jsi::Value returnValue = jsi::Value::null();
        if (returnObject != nullptr) {
          // Unbox via Boolean.booleanValue().
          jclass booleanClass = env->FindClass("java/lang/Boolean");
          jmethodID booleanValueMethod =
              env->GetMethodID(booleanClass, "booleanValue", "()Z");
          bool returnBoolean =
              (bool)env->CallBooleanMethod(returnObject, booleanValueMethod);
          checkJNIErrorForMethodCall();
          returnValue = jsi::Value(returnBoolean);
        }

        TMPL::syncMethodCallReturnConversionEnd(moduleName, methodName);
        TMPL::syncMethodCallEnd(moduleName, methodName);
        return returnValue;
      }

      bool returnBoolean =
          (bool)env->CallBooleanMethodA(instance, methodID, jargs.data());
      checkJNIErrorForMethodCall();

      TMPL::syncMethodCallExecutionEnd(moduleName, methodName);
      TMPL::syncMethodCallReturnConversionStart(moduleName, methodName);

      jsi::Value returnValue = jsi::Value(returnBoolean);

      TMPL::syncMethodCallReturnConversionEnd(moduleName, methodName);
      TMPL::syncMethodCallEnd(moduleName, methodName);

      return returnValue;
    }
    case NumberKind: {
      // Same boxed-vs-primitive split as BooleanKind, for
      // java.lang.Double / double.
      std::string returnType =
          methodSignature.substr(methodSignature.find_last_of(')') + 1);
      if (returnType == "Ljava/lang/Double;") {
        auto returnObject =
            (jobject)env->CallObjectMethodA(instance, methodID, jargs.data());
        checkJNIErrorForMethodCall();

        TMPL::syncMethodCallExecutionEnd(moduleName, methodName);
        TMPL::syncMethodCallReturnConversionStart(moduleName, methodName);

        jsi::Value returnValue = jsi::Value::null();
        if (returnObject != nullptr) {
          // Unbox via Double.doubleValue().
          jclass doubleClass = env->FindClass("java/lang/Double");
          jmethodID doubleValueMethod =
              env->GetMethodID(doubleClass, "doubleValue", "()D");
          double returnDouble =
              (double)env->CallDoubleMethod(returnObject, doubleValueMethod);
          checkJNIErrorForMethodCall();
          returnValue = jsi::Value(returnDouble);
        }

        TMPL::syncMethodCallReturnConversionEnd(moduleName, methodName);
        TMPL::syncMethodCallEnd(moduleName, methodName);
        return returnValue;
      }

      double returnDouble =
          (double)env->CallDoubleMethodA(instance, methodID, jargs.data());
      checkJNIErrorForMethodCall();

      TMPL::syncMethodCallExecutionEnd(moduleName, methodName);
      TMPL::syncMethodCallReturnConversionStart(moduleName, methodName);

      jsi::Value returnValue = jsi::Value(returnDouble);

      TMPL::syncMethodCallReturnConversionEnd(moduleName, methodName);
      TMPL::syncMethodCallEnd(moduleName, methodName);
      return returnValue;
    }
    case StringKind: {
      auto returnString =
          (jstring)env->CallObjectMethodA(instance, methodID, jargs.data());
      checkJNIErrorForMethodCall();

      TMPL::syncMethodCallExecutionEnd(moduleName, methodName);
      TMPL::syncMethodCallReturnConversionStart(moduleName, methodName);

      jsi::Value returnValue = jsi::Value::null();
      if (returnString != nullptr) {
        // Copy the characters into a std::string so the JNI buffer can be
        // released before the JS string is created.
        const char *js = env->GetStringUTFChars(returnString, nullptr);
        std::string result = js;
        env->ReleaseStringUTFChars(returnString, js);
        returnValue =
            jsi::Value(runtime, jsi::String::createFromUtf8(runtime, result));
      }

      TMPL::syncMethodCallReturnConversionEnd(moduleName, methodName);
      TMPL::syncMethodCallEnd(moduleName, methodName);
      return returnValue;
    }
    case ObjectKind: {
      auto returnObject =
          (jobject)env->CallObjectMethodA(instance, methodID, jargs.data());
      checkJNIErrorForMethodCall();

      TMPL::syncMethodCallExecutionEnd(moduleName, methodName);
      TMPL::syncMethodCallReturnConversionStart(moduleName, methodName);

      jsi::Value returnValue = jsi::Value::null();
      if (returnObject != nullptr) {
        // The returned object is treated as a NativeMap hybrid whose
        // contents are consumed and converted to a JS value.
        auto jResult = jni::adopt_local(returnObject);
        auto result = jni::static_ref_cast<NativeMap::jhybridobject>(jResult);
        returnValue =
            jsi::valueFromDynamic(runtime, result->cthis()->consume());
      }

      TMPL::syncMethodCallReturnConversionEnd(moduleName, methodName);
      TMPL::syncMethodCallEnd(moduleName, methodName);
      return returnValue;
    }
    case ArrayKind: {
      auto returnObject =
          (jobject)env->CallObjectMethodA(instance, methodID, jargs.data());
      checkJNIErrorForMethodCall();

      TMPL::syncMethodCallExecutionEnd(moduleName, methodName);
      TMPL::syncMethodCallReturnConversionStart(moduleName, methodName);

      jsi::Value returnValue = jsi::Value::null();
      if (returnObject != nullptr) {
        // Same as ObjectKind, but the hybrid is a NativeArray.
        auto jResult = jni::adopt_local(returnObject);
        auto result = jni::static_ref_cast<NativeArray::jhybridobject>(jResult);
        returnValue =
            jsi::valueFromDynamic(runtime, result->cthis()->consume());
      }

      TMPL::syncMethodCallReturnConversionEnd(moduleName, methodName);
      TMPL::syncMethodCallEnd(moduleName, methodName);
      return returnValue;
    }
    case VoidKind: {
      TMPL::asyncMethodCallArgConversionEnd(moduleName, methodName);
      TMPL::asyncMethodCallDispatch(moduleName, methodName);

      // The closure copies jargs/globalRefs and only holds the Java instance
      // weakly, so it does not depend on this function's locals once
      // scheduled.
      nativeInvoker_->invokeAsync(
          [jargs,
           globalRefs,
           methodID,
           instance_ = jni::make_weak(instance_),
           moduleNameStr = name_,
           methodNameStr,
           id = getUniqueId()]() mutable -> void {
            auto instance = instance_.lockLocal();
            /**
             * TODO(ramanpreet): Why do we have to require the environment
             * again? Why does JNI crash when we use the env from the upper
             * scope?
             */
            JNIEnv *env = jni::Environment::current();

            // This closure owns the GlobalReferences it captured, so they
            // must be deleted on every way out of it, not only on success.
            auto deleteGlobalRefs = [env, &globalRefs]() -> void {
              for (auto globalRef : globalRefs) {
                env->DeleteGlobalRef(globalRef);
              }
            };

            if (!instance) {
              // The Java module is already gone; nothing to call.
              deleteGlobalRefs();
              return;
            }

            const char *moduleName = moduleNameStr.c_str();
            const char *methodName = methodNameStr.c_str();

            TMPL::asyncMethodCallExecutionStart(moduleName, methodName, id);
            env->CallVoidMethodA(instance.get(), methodID, jargs.data());
            try {
              FACEBOOK_JNI_THROW_PENDING_EXCEPTION();
            } catch (...) {
              TMPL::asyncMethodCallExecutionFail(moduleName, methodName, id);
              deleteGlobalRefs();
              throw;
            }

            deleteGlobalRefs();
            TMPL::asyncMethodCallExecutionEnd(moduleName, methodName, id);
          });

      TMPL::asyncMethodCallEnd(moduleName, methodName);
      return jsi::Value::undefined();
    }
    case PromiseKind: {
      jsi::Function Promise =
          runtime.global().getPropertyAsFunction(runtime, "Promise");

      // Executor passed to `new Promise(fn)`. jargs/globalRefs/env are
      // captured from this stack frame, which is only valid if the executor
      // runs during callAsConstructor below.
      // NOTE(review): relies on the Promise constructor invoking its
      // executor synchronously — confirm for the target JS runtime.
      jsi::Function promiseConstructorArg = jsi::Function::createFromHostFunction(
          runtime,
          jsi::PropNameID::forAscii(runtime, "fn"),
          2,
          [this,
           &jargs,
           &globalRefs,
           argCount,
           methodID,
           moduleNameStr = name_,
           methodNameStr,
           env,
           retainJSCallback = retainJSCallback_](
              jsi::Runtime &runtime,
              const jsi::Value &thisVal,
              const jsi::Value *promiseConstructorArgs,
              size_t promiseConstructorArgCount) {
            if (promiseConstructorArgCount != 2) {
              throw std::invalid_argument("Promise fn arg count must be 2");
            }

            jsi::Function resolveJSIFn =
                promiseConstructorArgs[0].getObject(runtime).getFunction(
                    runtime);
            jsi::Function rejectJSIFn =
                promiseConstructorArgs[1].getObject(runtime).getFunction(
                    runtime);

            // Wrap the JS resolve/reject functions as Java Callback objects.
            auto resolve = createJavaCallbackFromJSIFunction(
                               retainJSCallback,
                               std::move(resolveJSIFn),
                               runtime,
                               jsInvoker_)
                               .release();
            auto reject = createJavaCallbackFromJSIFunction(
                              retainJSCallback,
                              std::move(rejectJSIFn),
                              runtime,
                              jsInvoker_)
                              .release();

            jclass jPromiseImpl =
                env->FindClass("com/facebook/react/bridge/PromiseImpl");
            jmethodID jPromiseImplConstructor = env->GetMethodID(
                jPromiseImpl,
                "<init>",
                "(Lcom/facebook/react/bridge/Callback;Lcom/facebook/react/bridge/Callback;)V");

            jobject promise = env->NewObject(
                jPromiseImpl, jPromiseImplConstructor, resolve, reject);

            const char *moduleName = moduleNameStr.c_str();
            const char *methodName = methodNameStr.c_str();

            // Promote the PromiseImpl to a GlobalReference so it stays valid
            // until the scheduled closure below has run; it is released
            // together with the other globalRefs.
            jobject globalPromise = env->NewGlobalRef(promise);

            globalRefs.push_back(globalPromise);
            env->DeleteLocalRef(promise);

            // The promise is passed as the Java argument that follows the
            // argCount JS arguments.
            // NOTE(review): assumes convertJSIArgsToJNIArgs sized jargs with
            // room for this extra slot — confirm.
            jargs[argCount].l = globalPromise;
            TMPL::asyncMethodCallArgConversionEnd(moduleName, methodName);
            TMPL::asyncMethodCallDispatch(moduleName, methodName);

            nativeInvoker_->invokeAsync(
                [jargs,
                 globalRefs,
                 methodID,
                 instance_ = jni::make_weak(instance_),
                 moduleNameStr,
                 methodNameStr,
                 id = getUniqueId()]() mutable -> void {
                  auto instance = instance_.lockLocal();
                  /**
                   * TODO(ramanpreet): Why do we have to require the
                   * environment again? Why does JNI crash when we use the env
                   * from the upper scope?
                   */
                  JNIEnv *env = jni::Environment::current();

                  // This closure owns the GlobalReferences it captured
                  // (arguments and the PromiseImpl), so they must be deleted
                  // on every way out of it, not only on success.
                  auto deleteGlobalRefs = [env, &globalRefs]() -> void {
                    for (auto globalRef : globalRefs) {
                      env->DeleteGlobalRef(globalRef);
                    }
                  };

                  if (!instance) {
                    // The Java module is already gone; nothing to call.
                    deleteGlobalRefs();
                    return;
                  }

                  const char *moduleName = moduleNameStr.c_str();
                  const char *methodName = methodNameStr.c_str();

                  TMPL::asyncMethodCallExecutionStart(
                      moduleName, methodName, id);
                  env->CallVoidMethodA(instance.get(), methodID, jargs.data());
                  try {
                    FACEBOOK_JNI_THROW_PENDING_EXCEPTION();
                  } catch (...) {
                    TMPL::asyncMethodCallExecutionFail(
                        moduleName, methodName, id);
                    deleteGlobalRefs();
                    throw;
                  }

                  deleteGlobalRefs();
                  TMPL::asyncMethodCallExecutionEnd(moduleName, methodName, id);
                });

            return jsi::Value::undefined();
          });

      jsi::Value promise =
          Promise.callAsConstructor(runtime, promiseConstructorArg);
      // Surfaces any Java exception left pending by the JNI calls made in
      // the executor above.
      checkJNIErrorForMethodCall();

      TMPL::asyncMethodCallEnd(moduleName, methodName);

      return promise;
    }
    default:
      throw std::runtime_error(
          "Unable to find method module: " + methodNameStr + "(" +
          methodSignature + ")");
  }
}