in source/neuropod/bindings/java/src/main/native/com_uber_neuropod_NeuropodTensor.cc [36:73]
/**
 * JNI implementation of NeuropodTensor.nativeGetBuffer.
 *
 * Looks up the tensor referenced by `nativeHandle`. For numeric element types it
 * returns the object produced by njni::createDirectBuffer<T>. That object is presumably
 * a direct java.nio buffer over the tensor's memory — NOTE(review): confirm against
 * njni::createDirectBuffer.
 *
 * @param env          JNI environment of the calling thread.
 * @param nativeHandle Address of a std::shared_ptr<neuropod::NeuropodValue>, passed
 *                     from Java as a jlong. It must be non-zero and point at a
 *                     non-empty shared_ptr.
 * @return A buffer for FLOAT_TENSOR, DOUBLE_TENSOR, INT32_TENSOR and INT64_TENSOR.
 *         null for STRING_TENSOR, whose data is flattened into a string list by a
 *         separate path.
 *         null with a pending Java exception when the handle is invalid, the tensor
 *         type is unsupported, or any C++ exception is raised.
 */
JNIEXPORT jobject JNICALL Java_com_uber_neuropod_NeuropodTensor_nativeGetBuffer(JNIEnv *env,
                                                                                jclass /*unused*/,
                                                                                jlong nativeHandle)
{
    try
    {
        // Dereferencing a zero handle would crash the JVM; report it as a Java exception instead.
        if (nativeHandle == 0)
        {
            throw std::runtime_error("NeuropodTensor native handle is null");
        }

        const auto &value = *reinterpret_cast<std::shared_ptr<neuropod::NeuropodValue> *>(nativeHandle);
        if (!value)
        {
            throw std::runtime_error("NeuropodTensor native handle refers to an empty value");
        }

        auto neuropodTensor = value->as_tensor();
        auto tensorType     = neuropodTensor->get_tensor_type();
        switch (tensorType)
        {
        case neuropod::FLOAT_TENSOR:
            return njni::createDirectBuffer<float>(env, neuropodTensor);
        case neuropod::DOUBLE_TENSOR:
            return njni::createDirectBuffer<double>(env, neuropodTensor);
        case neuropod::INT32_TENSOR:
            return njni::createDirectBuffer<int32_t>(env, neuropodTensor);
        case neuropod::INT64_TENSOR:
            return njni::createDirectBuffer<int64_t>(env, neuropodTensor);
        case neuropod::STRING_TENSOR:
            // String tensor data is flattened and converted to a string list elsewhere,
            // so no buffer is needed here. Return null directly rather than creating
            // a (null) global reference.
            return nullptr;
        default:
            throw std::runtime_error("unsupported tensor type: " + njni::tensor_type_to_string(tensorType));
        }
    }
    catch (const std::exception &e)
    {
        njni::throw_java_exception(env, e.what());
    }
    catch (...)
    {
        // No C++ exception may unwind through a JNI frame; convert it to a Java exception.
        njni::throw_java_exception(env, "unknown native error in NeuropodTensor.nativeGetBuffer");
    }
    return nullptr;
}