JNIEXPORT jlong JNICALL Java_com_uber_neuropod_NeuropodTensorAllocator_nativeAllocate()

in source/neuropod/bindings/java/src/main/native/com_uber_neuropod_NeuropodTensorAllocator.cc [50:99]


/**
 * Creates a neuropod tensor backed by the memory of a direct java.nio.Buffer.
 *
 * A JNI global reference to `buffer` is held until the tensor's deleter runs, which keeps the
 * buffer (and therefore its memory) reachable for as long as the tensor may use it.
 *
 * @param env        JNI environment of the calling thread.
 * @param dims       Tensor shape, one entry per dimension.
 * @param typeNumber Integer value of a neuropod::TensorType; INT32, INT64, FLOAT and DOUBLE are supported.
 * @param buffer     Direct buffer holding the tensor data.
 * @param handle     Pointer to a std::shared_ptr<neuropod::NeuropodTensorAllocator>.
 * @return The value produced by njni::toHeap(tensor), or 0 when a Java exception has been raised.
 *
 * NOTE(review): the buffer's capacity is not validated against the element count implied by
 * `dims`; the caller is presumably responsible for sizing it correctly — confirm on the Java side.
 */
JNIEXPORT jlong JNICALL Java_com_uber_neuropod_NeuropodTensorAllocator_nativeAllocate(
    JNIEnv *env, jclass /* unused */, jlongArray dims, jint typeNumber, jobject buffer, jlong handle)
{
    try
    {
        if (handle == 0)
        {
            throw std::runtime_error("invalid (null) allocator handle");
        }
        // Borrow the allocator; there is no need to bump the shared_ptr refcount for this call
        const auto &allocator = *reinterpret_cast<std::shared_ptr<neuropod::NeuropodTensorAllocator> *>(handle);

        // Prepare shape
        const jsize size = env->GetArrayLength(dims);
        jlong *     arr  = env->GetLongArrayElements(dims, nullptr);
        if (arr == nullptr)
        {
            // The JVM has already raised an exception (OutOfMemoryError); just return
            return reinterpret_cast<jlong>(nullptr);
        }
        const std::vector<int64_t> shapes(arr, arr + size);
        // JNI_ABORT: the elements were only read, so nothing needs to be copied back
        env->ReleaseLongArrayElements(dims, arr, JNI_ABORT);

        // Prepare buffer. Validate everything that can fail *before* creating the global ref so
        // that an early throw cannot leak it.
        void *const bufferAddress = env->GetDirectBufferAddress(buffer);
        if (bufferAddress == nullptr)
        {
            throw std::runtime_error("buffer must be a direct java.nio.Buffer");
        }

        JavaVM *jvm = nullptr;
        if (env->GetJavaVM(&jvm) != JNI_OK || jvm == nullptr)
        {
            throw std::runtime_error("failed to obtain the JavaVM");
        }

        const jobject globalBufferRef = env->NewGlobalRef(buffer);
        if (globalBufferRef == nullptr)
        {
            if (env->ExceptionCheck())
            {
                // A Java exception is already pending; don't raise another one on top of it
                return reinterpret_cast<jlong>(nullptr);
            }
            throw std::runtime_error("failed to create a global reference to buffer");
        }

        // A JNIEnv is only valid on the thread it was obtained on, and the tensor (and hence this
        // deleter) may be destroyed on another thread. Capture the JavaVM instead and look up, or
        // temporarily attach, an env for whichever thread runs the deleter.
        const neuropod::Deleter deleter = [jvm, globalBufferRef](void * /* unused */) {
            JNIEnv *   threadEnv = nullptr;
            const jint status    = jvm->GetEnv(reinterpret_cast<void **>(&threadEnv), JNI_VERSION_1_6);
            if (status == JNI_OK)
            {
                threadEnv->DeleteGlobalRef(globalBufferRef);
            }
            else if (status == JNI_EDETACHED &&
                     jvm->AttachCurrentThread(reinterpret_cast<void **>(&threadEnv), nullptr) == JNI_OK)
            {
                threadEnv->DeleteGlobalRef(globalBufferRef);
                jvm->DetachCurrentThread();
            }
            // Otherwise (e.g. the JVM is shutting down) there is nothing that can be done safely.
        };

        std::shared_ptr<neuropod::NeuropodValue> tensor;
        switch (static_cast<neuropod::TensorType>(typeNumber))
        {
        case neuropod::INT32_TENSOR: {
            tensor =
                allocator->tensor_from_memory<int32_t>(shapes, reinterpret_cast<int32_t *>(bufferAddress), deleter);
            break;
        }
        case neuropod::INT64_TENSOR: {
            tensor =
                allocator->tensor_from_memory<int64_t>(shapes, reinterpret_cast<int64_t *>(bufferAddress), deleter);
            break;
        }
        case neuropod::FLOAT_TENSOR: {
            tensor = allocator->tensor_from_memory<float>(shapes, reinterpret_cast<float *>(bufferAddress), deleter);
            break;
        }
        case neuropod::DOUBLE_TENSOR: {
            tensor = allocator->tensor_from_memory<double>(shapes, reinterpret_cast<double *>(bufferAddress), deleter);
            break;
        }
        default:
            // No tensor was created, so the deleter will never run; release the global ref here
            env->DeleteGlobalRef(globalBufferRef);
            throw std::runtime_error("unsupported tensor type");
        }
        return reinterpret_cast<jlong>(njni::toHeap(tensor));
    }
    catch (const std::exception &e)
    {
        njni::throw_java_exception(env, e.what());
    }
    catch (...)
    {
        // Never let a C++ exception propagate across the JNI boundary
        njni::throw_java_exception(env, "unknown native exception in nativeAllocate");
    }
    return reinterpret_cast<jlong>(nullptr);
}