JNIEXPORT jobject JNICALL Java_com_uber_neuropod_NeuropodTensor_nativeGetBuffer()

in source/neuropod/bindings/java/src/main/native/com_uber_neuropod_NeuropodTensor.cc [36:73]


// JNI entry point for NeuropodTensor.nativeGetBuffer.
//
// Interprets `nativeHandle` as a pointer to a std::shared_ptr<neuropod::NeuropodValue>,
// views the value as a tensor and dispatches on its element type:
//   - FLOAT / DOUBLE / INT32 / INT64 tensors: returns the result of
//     njni::createDirectBuffer<T>(env, tensor).
//     NOTE(review): presumably a java.nio direct buffer over the tensor data; confirm
//     against that helper.
//   - STRING tensors: returns null (no buffer is needed, see the comment in that case).
//   - any other type: raises a Java exception carrying "unsupported tensor type: ...".
//
// No C++ exception is allowed to leave this function: every failure is reported through
// njni::throw_java_exception and the function then returns nullptr.
JNIEXPORT jobject JNICALL Java_com_uber_neuropod_NeuropodTensor_nativeGetBuffer(JNIEnv *env,
                                                                                jclass /*unused*/,
                                                                                jlong nativeHandle)
{
    try
    {
        // Validate the handle before dereferencing it: a zero handle or an empty shared_ptr
        // would otherwise take down the whole JVM instead of raising a Java exception.
        auto *valueHandle = reinterpret_cast<std::shared_ptr<neuropod::NeuropodValue> *>(nativeHandle);
        if (valueHandle == nullptr || !(*valueHandle))
        {
            throw std::runtime_error("invalid native handle: tensor is null");
        }

        auto neuropodTensor = (*valueHandle)->as_tensor();
        auto tensorType     = neuropodTensor->get_tensor_type();
        switch (tensorType)
        {
        case neuropod::FLOAT_TENSOR:
            return njni::createDirectBuffer<float>(env, neuropodTensor);
        case neuropod::DOUBLE_TENSOR:
            return njni::createDirectBuffer<double>(env, neuropodTensor);
        case neuropod::INT32_TENSOR:
            return njni::createDirectBuffer<int32_t>(env, neuropodTensor);
        case neuropod::INT64_TENSOR:
            return njni::createDirectBuffer<int64_t>(env, neuropodTensor);
        case neuropod::STRING_TENSOR:
            // For STRING_TENSOR the tensor data is flattened and converted to a string list
            // elsewhere, so no buffer is needed to store the data. Return null directly
            // (NewGlobalRef on a null reference would yield null as well).
            return nullptr;
        default:
            throw std::runtime_error("unsupported tensor type: " + njni::tensor_type_to_string(tensorType));
        }
    }
    catch (const std::exception &e)
    {
        njni::throw_java_exception(env, e.what());
    }
    catch (...)
    {
        // Letting a C++ exception unwind through the JNI boundary is undefined behaviour,
        // so anything not derived from std::exception is converted as well.
        njni::throw_java_exception(env, "unknown native error in nativeGetBuffer");
    }
    // Reached only after a Java exception has been raised above.
    return nullptr;
}