JNIEXPORT jobject JNICALL Java_com_uber_neuropod_Neuropod_nativeInfer()

in source/neuropod/bindings/java/src/main/native/com_uber_neuropod_Neuropod.cc [229:299]


// Native side of Neuropod inference.
//
// entryArray           - array of java.util.Map.Entry objects; each key is read as a String input
//                        name, and each value's NeuropodTensor.getHandle() result is reinterpreted
//                        as a std::shared_ptr<neuropod::NeuropodValue>*.
// requestedOutputsJava - optional java.util.ArrayList of String output names forwarded to
//                        Neuropod::infer; may be null (no explicit outputs requested).
// modelHandle          - native pointer to a neuropod::Neuropod.
//
// Returns a new java.util.HashMap mapping output names to newly constructed NeuropodTensor objects.
// On any failure a Java exception is raised via njni::throw_java_exception and nullptr is returned.
//
// Per-iteration local references are released inside each loop so that large inputs/outputs do not
// exhaust the JNI local reference table.
JNIEXPORT jobject JNICALL Java_com_uber_neuropod_Neuropod_nativeInfer(
    JNIEnv *env, jclass, jobjectArray entryArray, jobject requestedOutputsJava, jlong modelHandle)
{
    try
    {
        // Prepare requestedOutputs
        std::vector<std::string> requestedOutputs;
        if (requestedOutputsJava != nullptr)
        {
            const jsize size = env->CallIntMethod(requestedOutputsJava, njni::java_util_ArrayList_size);
            if (env->ExceptionCheck())
            {
                throw std::runtime_error("ArrayList.size failed: cannot read requested outputs");
            }
            for (jsize i = 0; i < size; i++)
            {
                auto element =
                    static_cast<jstring>(env->CallObjectMethod(requestedOutputsJava, njni::java_util_ArrayList_get, i));
                // A null name would be dereferenced by to_string; a pending exception must not be ignored.
                if (element == nullptr || env->ExceptionCheck())
                {
                    throw std::runtime_error("invalid requested output name");
                }
                requestedOutputs.emplace_back(njni::to_string(env, element));
                env->DeleteLocalRef(element);
            }
        }

        // Fill in NeuropodValueMap
        if (entryArray == nullptr)
        {
            throw std::runtime_error("invalid input map");
        }
        const jsize                entrySize = env->GetArrayLength(entryArray);
        neuropod::NeuropodValueMap nativeMap;
        for (jsize i = 0; i < entrySize; i++)
        {
            jobject entry = env->GetObjectArrayElement(entryArray, i);
            if (entry == nullptr || env->ExceptionCheck())
            {
                throw std::runtime_error("invalid input entry");
            }

            // Keep the key's local reference so it can be released after conversion
            // (to_string does not release it).
            auto keyJava = static_cast<jstring>(env->CallObjectMethod(entry, njni::java_util_Map_Entry_getKey));
            if (keyJava == nullptr || env->ExceptionCheck())
            {
                throw std::runtime_error("invalid input name");
            }
            std::string key = njni::to_string(env, keyJava);
            env->DeleteLocalRef(keyJava);

            jobject value = env->CallObjectMethod(entry, njni::java_util_Map_Entry_getValue);
            // Calling getHandle on a null object is undefined behaviour in JNI.
            if (value == nullptr || env->ExceptionCheck())
            {
                throw std::runtime_error("invalid tensor handle");
            }
            const jlong tensorHandle = env->CallLongMethod(value, njni::com_uber_neuropod_NeuropodTensor_getHandle);
            if (tensorHandle == 0 || env->ExceptionCheck())
            {
                throw std::runtime_error("invalid tensor handle");
            }
            nativeMap.emplace(std::move(key),
                              *reinterpret_cast<std::shared_ptr<neuropod::NeuropodValue> *>(tensorHandle));
            env->DeleteLocalRef(entry);
            env->DeleteLocalRef(value);
        }

        auto model = reinterpret_cast<neuropod::Neuropod *>(modelHandle);
        if (model == nullptr)
        {
            throw std::runtime_error("invalid model handle");
        }
        auto inferredMap = model->infer(nativeMap, requestedOutputs);

        // Put data to Java Map
        jobject ret = env->NewObject(njni::java_util_HashMap, njni::java_util_HashMap_);
        if (!ret || env->ExceptionCheck())
        {
            throw std::runtime_error("NewObject failed: cannot create HashMap");
        }

        for (auto &entry : *inferredMap)
        {
            // NOTE(review): njni::toHeap appears to allocate a heap copy whose ownership passes to the
            // Java tensor; if NewObject fails that allocation is not released here. Confirm toHeap's
            // ownership contract before adding cleanup.
            jobject javaTensor = env->NewObject(njni::com_uber_neuropod_NeuropodTensor,
                                                njni::com_uber_neuropod_NeuropodTensor_,
                                                reinterpret_cast<jlong>(njni::toHeap(entry.second)));
            if (!javaTensor || env->ExceptionCheck())
            {
                throw std::runtime_error("NewObject failed: cannot create Tensor");
            }

            jstring name = env->NewStringUTF(entry.first.c_str());
            if (!name || env->ExceptionCheck())
            {
                throw std::runtime_error("NewStringUTF failed: cannot create output name");
            }

            // HashMap.put returns the key's previous value (or null) as a new local reference.
            jobject previous = env->CallObjectMethod(ret, njni::java_util_HashMap_put, name, javaTensor);
            if (env->ExceptionCheck())
            {
                throw std::runtime_error("HashMap.put failed: cannot store output tensor");
            }
            if (previous != nullptr)
            {
                env->DeleteLocalRef(previous);
            }
            env->DeleteLocalRef(name);
            env->DeleteLocalRef(javaTensor);
        }
        return ret;
    }
    catch (const std::exception &e)
    {
        njni::throw_java_exception(env, e.what());
    }
    return nullptr;
}