in jni/src/faiss_wrapper.cpp [112:170]
/**
 * @brief Builds a faiss index from a serialized template index, adds the given vectors under the
 *        given ids, and writes the result to disk.
 *
 * The template is deserialized with faiss::read_index, wrapped in a faiss::IndexIDMap so that each
 * vector is added with its caller-supplied id, and the wrapped index is written to indexPathJ.
 *
 * @param jniUtil        Helper used for every interaction with the JVM.
 * @param env            JNI environment of the calling thread.
 * @param idsJ           Java int array of ids; must have one entry per vector.
 * @param vectorsJ       Java 2d float array holding the vectors to add.
 * @param indexPathJ     Java string with the path the index is written to.
 * @param templateIndexJ Java byte array holding the serialized template index.
 * @param parametersJ    Java map of parameters; only knn_jni::INDEX_THREAD_QUANTITY is read here.
 *                       Its local reference is deleted by this function.
 *
 * @throws std::runtime_error if idsJ, vectorsJ, indexPathJ or templateIndexJ is null, or if the
 *         number of ids differs from the number of vectors. Exceptions raised by the jniUtil
 *         helpers or by faiss are not caught here.
 */
void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
                                                     jobjectArray vectorsJ, jstring indexPathJ,
                                                     jbyteArray templateIndexJ, jobject parametersJ) {
    if (idsJ == nullptr) {
        throw std::runtime_error("IDs cannot be null");
    }

    if (vectorsJ == nullptr) {
        throw std::runtime_error("Vectors cannot be null");
    }

    if (indexPathJ == nullptr) {
        throw std::runtime_error("Index path cannot be null");
    }

    if (templateIndexJ == nullptr) {
        throw std::runtime_error("Template index cannot be null");
    }

    // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
    auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);
    // Reuse the iterator from find() so the key is only looked up once.
    auto threadQuantityEntry = parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY);
    if (threadQuantityEntry != parametersCpp.end()) {
        auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, threadQuantityEntry->second);
        omp_set_num_threads(threadCount);
    }
    jniUtil->DeleteLocalRef(env, parametersJ);

    // Read data set
    int numVectors = jniUtil->GetJavaObjectArrayLength(env, vectorsJ);
    int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ);
    if (numIds != numVectors) {
        throw std::runtime_error("Number of IDs does not match number of vectors");
    }

    int dim = jniUtil->GetInnerDimensionOf2dJavaFloatArray(env, vectorsJ);
    auto dataset = jniUtil->Convert2dJavaObjectArrayToCppFloatVector(env, vectorsJ, dim);

    // Get vector of bytes from jbytearray
    int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
    jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

    faiss::VectorIOReader vectorIoReader;
    try {
        // Copy the whole template in one go instead of growing the vector byte by byte.
        const auto * firstByte = reinterpret_cast<const uint8_t *>(indexBytesJ);
        vectorIoReader.data.assign(firstByte, firstByte + indexBytesCount);
    } catch (...) {
        // The copy can fail (e.g. allocation failure); hand the buffer back to the JVM before propagating.
        jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);
        throw;
    }
    // The bytes were only read, so JNI_ABORT releases the buffer without copying anything back.
    jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

    // Create faiss index
    std::unique_ptr<faiss::Index> indexWriter;
    indexWriter.reset(faiss::read_index(&vectorIoReader, 0));

    // indexWriter stays the owner of the deserialized index; idMap is only handed the raw pointer.
    auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ);
    faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get());
    idMap.add_with_ids(numVectors, dataset.data(), idVector.data());

    // Write the index to disk
    std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
    faiss::write_index(&idMap, indexPathCpp.c_str());
}