inference/wasm/bindings/service_bindings.cpp (76 lines of code) (raw):

/*
 * Bindings for Service class
 */

#include <emscripten/bind.h>

#include "service.h"

using namespace emscripten;

using BlockingService = marian::bergamot::BlockingService;
using TranslationModel = marian::bergamot::TranslationModel;
using AlignedMemory = marian::bergamot::AlignedMemory;
using MemoryBundle = marian::bergamot::MemoryBundle;

// Wraps the storage of `alignedMemory` in a typed memory view of chars and returns it as a
// JavaScript value. The view spans alignedMemory.size() elements starting at
// alignedMemory.as<char>().
val getByteArrayView(AlignedMemory& alignedMemory) {
  const auto byteCount = alignedMemory.size();
  auto* const firstByte = alignedMemory.as<char>();
  return val(typed_memory_view(byteCount, firstByte));
}

// Registers AlignedMemory with embind: a two-argument (size_t, size_t) constructor, its size()
// member, the getByteArrayView helper as a method, and a vector of raw AlignedMemory pointers
// under the name "AlignedMemoryList".
EMSCRIPTEN_BINDINGS(aligned_memory) {
  class_<AlignedMemory>("AlignedMemory")
      .constructor<std::size_t, std::size_t>()
      .function("size", &AlignedMemory::size)
      .function("getByteArrayView", &getByteArrayView);

  register_vector<AlignedMemory*>("AlignedMemoryList");
}

// When the source and target vocab files are the same, only one memory object is passed from JS
// to avoid allocating memory twice for the same file. However, the constructor of the Service
// class still expects 2 entries in this case, where each entry has shared ownership of the
// same AlignedMemory object. This function prepares these smart pointer based AlignedMemory
// objects for the unique AlignedMemory objects passed from JS.
std::vector<std::shared_ptr<AlignedMemory>> prepareVocabsSmartMemories(std::vector<AlignedMemory*>& vocabsMemories) { auto sourceVocabMemory = std::make_shared<AlignedMemory>(std::move(*(vocabsMemories[0]))); std::vector<std::shared_ptr<AlignedMemory>> vocabsSmartMemories; vocabsSmartMemories.push_back(sourceVocabMemory); if (vocabsMemories.size() == 2) { auto targetVocabMemory = std::make_shared<AlignedMemory>(std::move(*(vocabsMemories[1]))); vocabsSmartMemories.push_back(std::move(targetVocabMemory)); } else { vocabsSmartMemories.push_back(sourceVocabMemory); } return vocabsSmartMemories; } MemoryBundle prepareMemoryBundle(AlignedMemory* modelMemory, AlignedMemory* shortlistMemory, std::vector<AlignedMemory*> uniqueVocabsMemories, AlignedMemory* qualityEstimatorMemory) { MemoryBundle memoryBundle; memoryBundle.models.emplace_back(std::move(*modelMemory)); memoryBundle.shortlist = std::move(*shortlistMemory); memoryBundle.vocabs = std::move(prepareVocabsSmartMemories(uniqueVocabsMemories)); if (qualityEstimatorMemory != nullptr) { memoryBundle.qualityEstimatorMemory = std::move(*qualityEstimatorMemory); } return memoryBundle; } // This allows only shared_ptrs to be operational in JavaScript, according to emscripten. 
// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/embind.html#smart-pointers std::shared_ptr<TranslationModel> TranslationModelFactory( const std::string& sourceLanguage, const std::string& targetLanguage, const std::string& config, AlignedMemory* model, AlignedMemory* shortlist, std::vector<AlignedMemory*> vocabs, AlignedMemory* qualityEstimator ) { MemoryBundle memoryBundle = prepareMemoryBundle(model, shortlist, vocabs, qualityEstimator); std::shared_ptr<TranslationModel> translationModel = std::make_shared<TranslationModel>(config, std::move(memoryBundle)); translationModel->registerSourceLanguage(sourceLanguage); translationModel->registerTargetLanguage(targetLanguage); return translationModel; } EMSCRIPTEN_BINDINGS(translation_model) { class_<TranslationModel>("TranslationModel") .smart_ptr_constructor("TranslationModel", &TranslationModelFactory, allow_raw_pointers()); } EMSCRIPTEN_BINDINGS(blocking_service_config) { value_object<BlockingService::Config>("BlockingServiceConfig") .field("cacheSize", &BlockingService::Config::cacheSize); } std::shared_ptr<BlockingService> BlockingServiceFactory(const BlockingService::Config& config) { auto copy = config; copy.logger.level = "critical"; return std::make_shared<BlockingService>(copy); } EMSCRIPTEN_BINDINGS(blocking_service) { class_<BlockingService>("BlockingService") .smart_ptr_constructor("BlockingService", &BlockingServiceFactory) .function("translate", &BlockingService::translateMultiple) .function("translateViaPivoting", &BlockingService::pivotMultiple); register_vector<std::string>("VectorString"); }