in source/neuropod/bindings/java/src/main/native/com_uber_neuropod_Neuropod.cc [229:299]
// Native implementation of Neuropod inference.
//
// entryArray           - array whose elements are read through java.util.Map.Entry#getKey / #getValue;
//                        keys are input names (String), values are NeuropodTensor objects whose
//                        getHandle() points to a heap-allocated std::shared_ptr<neuropod::NeuropodValue>.
// requestedOutputsJava - optional java.util.ArrayList of output names; when null, an empty list is
//                        passed to neuropod::Neuropod::infer.
// modelHandle          - address of a neuropod::Neuropod, passed as a jlong.
//
// Returns a new java.util.HashMap mapping output names to newly created NeuropodTensor objects.
// On any failure the error is reported through njni::throw_java_exception and nullptr is returned.
JNIEXPORT jobject JNICALL Java_com_uber_neuropod_Neuropod_nativeInfer(
    JNIEnv *env, jclass, jobjectArray entryArray, jobject requestedOutputsJava, jlong modelHandle)
{
    // After a Java call, any pending Java exception is turned into a C++ exception, so that no
    // further JNI calls (other than the exception-safe ones in the catch blocks) are made while an
    // exception is pending.
    const auto checkJava = [env](const char *message) {
        if (env->ExceptionCheck())
        {
            throw std::runtime_error(message);
        }
    };

    try
    {
        if (modelHandle == 0)
        {
            throw std::runtime_error("invalid model handle");
        }
        if (entryArray == nullptr)
        {
            throw std::runtime_error("input array is null");
        }

        // Prepare requestedOutputs
        std::vector<std::string> requestedOutputs;
        if (requestedOutputsJava != nullptr)
        {
            const jsize size = env->CallIntMethod(requestedOutputsJava, njni::java_util_ArrayList_size);
            checkJava("cannot read the number of requested outputs");
            if (size > 0)
            {
                requestedOutputs.reserve(static_cast<std::size_t>(size));
            }
            for (jsize i = 0; i < size; i++)
            {
                auto element =
                    static_cast<jstring>(env->CallObjectMethod(requestedOutputsJava, njni::java_util_ArrayList_get, i));
                checkJava("cannot read a requested output name");
                if (element == nullptr)
                {
                    throw std::runtime_error("requested output name is null");
                }
                requestedOutputs.emplace_back(njni::to_string(env, element));
                env->DeleteLocalRef(element);
            }
        }

        // Fill in NeuropodValueMap
        const jsize entrySize = env->GetArrayLength(entryArray);
        neuropod::NeuropodValueMap nativeMap;
        for (jsize i = 0; i < entrySize; i++)
        {
            jobject entry = env->GetObjectArrayElement(entryArray, i);
            checkJava("cannot read an input entry");
            if (entry == nullptr)
            {
                throw std::runtime_error("input entry is null");
            }

            auto javaKey = static_cast<jstring>(env->CallObjectMethod(entry, njni::java_util_Map_Entry_getKey));
            checkJava("cannot read an input name");
            if (javaKey == nullptr)
            {
                throw std::runtime_error("input name is null");
            }
            std::string key = njni::to_string(env, javaKey);
            // Released explicitly: local refs created in a loop would otherwise accumulate until return.
            env->DeleteLocalRef(javaKey);

            jobject value = env->CallObjectMethod(entry, njni::java_util_Map_Entry_getValue);
            if (value == nullptr || env->ExceptionCheck())
            {
                throw std::runtime_error("invalid tensor handle");
            }
            const jlong tensorHandle = env->CallLongMethod(value, njni::com_uber_neuropod_NeuropodTensor_getHandle);
            if (tensorHandle == 0 || env->ExceptionCheck())
            {
                throw std::runtime_error("invalid tensor handle");
            }
            // Copies the shared_ptr, so the native map co-owns the tensor while infer runs.
            nativeMap.emplace(std::move(key), *reinterpret_cast<std::shared_ptr<neuropod::NeuropodValue> *>(tensorHandle));
            env->DeleteLocalRef(entry);
            env->DeleteLocalRef(value);
        }

        auto model         = reinterpret_cast<neuropod::Neuropod *>(modelHandle);
        auto inferredMap   = model->infer(nativeMap, requestedOutputs);
        if (!inferredMap)
        {
            throw std::runtime_error("inference returned no outputs");
        }

        // Put data to Java Map
        jobject ret = env->NewObject(njni::java_util_HashMap, njni::java_util_HashMap_);
        if (!ret || env->ExceptionCheck())
        {
            throw std::runtime_error("NewObject failed: cannot create HashMap");
        }
        for (auto &output : *inferredMap)
        {
            // NOTE(review): if NewObject fails, the object created by njni::toHeap is not released here;
            // confirm how toHeap allocates before adding cleanup.
            jobject javaTensor = env->NewObject(njni::com_uber_neuropod_NeuropodTensor,
                                                njni::com_uber_neuropod_NeuropodTensor_,
                                                reinterpret_cast<jlong>(njni::toHeap(output.second)));
            if (!javaTensor || env->ExceptionCheck())
            {
                throw std::runtime_error("NewObject failed: cannot create Tensor");
            }

            jstring javaName = env->NewStringUTF(output.first.c_str());
            if (!javaName || env->ExceptionCheck())
            {
                throw std::runtime_error("NewStringUTF failed: cannot create output name");
            }

            // HashMap.put returns the previous value (a local ref) which must be released as well.
            jobject previous = env->CallObjectMethod(ret, njni::java_util_HashMap_put, javaName, javaTensor);
            checkJava("HashMap.put failed");
            if (previous != nullptr)
            {
                env->DeleteLocalRef(previous);
            }
            env->DeleteLocalRef(javaName);
            env->DeleteLocalRef(javaTensor);
        }
        return ret;
    }
    catch (const std::exception &e)
    {
        // Throwing a new Java exception is not allowed while one is already pending. Clear the pending
        // one first so the reported error replaces it.
        if (env->ExceptionCheck())
        {
            env->ExceptionClear();
        }
        njni::throw_java_exception(env, e.what());
    }
    catch (...)
    {
        // Never let a C++ exception unwind across the JNI boundary.
        if (env->ExceptionCheck())
        {
            env->ExceptionClear();
        }
        njni::throw_java_exception(env, "unknown native exception during inference");
    }
    return nullptr;
}