// torchaudio/csrc/decoder/bindings/pybind.cpp
#include <torch/extension.h>
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h"
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"
namespace py = pybind11;
using namespace torchaudio::lib::text;
using namespace py::literals;
/**
* Some hackery that lets pybind11 handle shared_ptr<void> (for old LMStatePtr).
* See: https://github.com/pybind/pybind11/issues/820
* PYBIND11_MAKE_OPAQUE(std::shared_ptr<void>);
* and inside PYBIND11_MODULE
* py::class_<std::shared_ptr<void>>(m, "encapsulated_data");
*/
namespace {
/**
* A pybind11 "alias type" for abstract class LM, allowing one to subclass LM
* with a custom LM defined purely in Python, useful for those who don't want
* to build with KenLM or who have their own custom LM implementation.
* See: https://pybind11.readthedocs.io/en/stable/advanced/classes.html
*
* TODO: ensure this works. Last time Jeff tried this there were slicing issues,
* see https://github.com/pybind/pybind11/issues/1546 for workarounds.
* This is low-pri since we assume most people can just build with KenLM.
*/
class PyLM : public LM {
using LM::LM;
// type alias needed because the comma in std::pair would otherwise break the
// PYBIND11_OVERLOAD_PURE macro
using LMOutput = std::pair<LMStatePtr, float>;
LMStatePtr start(bool startWithNothing) override {
PYBIND11_OVERLOAD_PURE(LMStatePtr, LM, start, startWithNothing);
}
LMOutput score(const LMStatePtr& state, const int usrTokenIdx) override {
PYBIND11_OVERLOAD_PURE(LMOutput, LM, score, state, usrTokenIdx);
}
LMOutput finish(const LMStatePtr& state) override {
PYBIND11_OVERLOAD_PURE(LMOutput, LM, finish, state);
}
};
/**
* Using a custom Python LMState derived from LMState does not work with a
* custom Python LM (derived from PyLM), because the score and finish functions
* would need to cast the LMState to the derived class
* (for example via obj.__class__ = CustomPyLMState), which causes the error
* "TypeError: __class__ assignment: 'CustomPyLMState' deallocator differs
* from 'flashlight.text.decoder._decoder.LMState'";
* for details see https://github.com/pybind/pybind11/issues/1640.
* To define a custom LM you can instead keep a map inside the LM that maps an
* LMState to additional state info (shared pointers to the same underlying
* object have the same id in Python inside score and finish), as in the
* example below:
*
* ```python
* import numpy
*
* from flashlight.lib.text.decoder import LM, LMState
*
* class MyPyLM(LM):
*     mapping_states = dict()  # store simple additional int for each state
*
*     def __init__(self):
*         LM.__init__(self)
*
*     def start(self, start_with_nothing):
*         state = LMState()
*         self.mapping_states[state] = 0
*         return state
*
*     def score(self, state, index):
*         outstate = state.child(index)
*         if outstate not in self.mapping_states:
*             self.mapping_states[outstate] = self.mapping_states[state] + 1
*         return (outstate, -numpy.random.random())
*
*     def finish(self, state):
*         outstate = state.child(-1)
*         if outstate not in self.mapping_states:
*             self.mapping_states[outstate] = self.mapping_states[state] + 1
*         return (outstate, -1)
* ```
*/
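// Thin wrappers around LexiconDecoder::decodeStep / decode: pybind11 cannot
// receive a raw const float* from Python, so the emissions buffer is passed
// as an integer address (e.g. torch.Tensor.data_ptr() of a contiguous float32
// tensor) and reinterpreted below.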
void LexiconDecoder_decodeStep(
LexiconDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
decoder.decodeStep(reinterpret_cast<const float*>(emissions), T, N);
}
std::vector<DecodeResult> LexiconDecoder_decode(
LexiconDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
}
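// Thin wrappers that disambiguate the two Dictionary::addEntry overloads for
// pybind11: one binds (entry, idx), the other just (entry).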
void Dictionary_addEntry_0(
Dictionary& dict,
const std::string& entry,
int idx) {
dict.addEntry(entry, idx);
}
void Dictionary_addEntry_1(Dictionary& dict, const std::string& entry) {
dict.addEntry(entry);
}
PYBIND11_MODULE(_torchaudio_decoder, m) {
#ifdef BUILD_CTC_DECODER
py::enum_<SmearingMode>(m, "_SmearingMode")
.value("NONE", SmearingMode::NONE)
.value("MAX", SmearingMode::MAX)
.value("LOGADD", SmearingMode::LOGADD);
py::class_<TrieNode, TrieNodePtr>(m, "_TrieNode")
.def(py::init<int>(), "idx"_a)
.def_readwrite("children", &TrieNode::children)
.def_readwrite("idx", &TrieNode::idx)
.def_readwrite("labels", &TrieNode::labels)
.def_readwrite("scores", &TrieNode::scores)
.def_readwrite("max_score", &TrieNode::maxScore);
py::class_<Trie, TriePtr>(m, "_Trie")
.def(py::init<int, int>(), "max_children"_a, "root_idx"_a)
.def("get_root", &Trie::getRoot)
.def("insert", &Trie::insert, "indices"_a, "label"_a, "score"_a)
.def("search", &Trie::search, "indices"_a)
.def("smear", &Trie::smear, "smear_mode"_a);
py::class_<LM, LMPtr, PyLM>(m, "_LM")
.def(py::init<>())
.def("start", &LM::start, "start_with_nothing"_a)
.def("score", &LM::score, "state"_a, "usr_token_idx"_a)
.def("finish", &LM::finish, "state"_a);
py::class_<LMState, LMStatePtr>(m, "_LMState")
.def(py::init<>())
.def_readwrite("children", &LMState::children)
.def("compare", &LMState::compare, "state"_a)
.def("child", &LMState::child<LMState>, "usr_index"_a);
py::class_<KenLM, KenLMPtr, LM>(m, "_KenLM")
.def(
py::init<const std::string&, const Dictionary&>(),
"path"_a,
"usr_token_dict"_a);
py::class_<ZeroLM, ZeroLMPtr, LM>(m, "_ZeroLM").def(py::init<>());
py::enum_<CriterionType>(m, "_CriterionType")
.value("ASG", CriterionType::ASG)
.value("CTC", CriterionType::CTC);
py::class_<LexiconDecoderOptions>(m, "_LexiconDecoderOptions")
.def(
py::init<
const int,
const int,
const double,
const double,
const double,
const double,
const double,
const bool,
const CriterionType>(),
"beam_size"_a,
"beam_size_token"_a,
"beam_threshold"_a,
"lm_weight"_a,
"word_score"_a,
"unk_score"_a,
"sil_score"_a,
"log_add"_a,
"criterion_type"_a)
.def_readwrite("beam_size", &LexiconDecoderOptions::beamSize)
.def_readwrite("beam_size_token", &LexiconDecoderOptions::beamSizeToken)
.def_readwrite("beam_threshold", &LexiconDecoderOptions::beamThreshold)
.def_readwrite("lm_weight", &LexiconDecoderOptions::lmWeight)
.def_readwrite("word_score", &LexiconDecoderOptions::wordScore)
.def_readwrite("unk_score", &LexiconDecoderOptions::unkScore)
.def_readwrite("sil_score", &LexiconDecoderOptions::silScore)
.def_readwrite("log_add", &LexiconDecoderOptions::logAdd)
.def_readwrite("criterion_type", &LexiconDecoderOptions::criterionType);
py::class_<DecodeResult>(m, "_DecodeResult")
.def(py::init<int>(), "length"_a)
.def_readwrite("score", &DecodeResult::score)
.def_readwrite("amScore", &DecodeResult::amScore)
.def_readwrite("lmScore", &DecodeResult::lmScore)
.def_readwrite("words", &DecodeResult::words)
.def_readwrite("tokens", &DecodeResult::tokens);
// NB: `decode` and `decodeStep` expect raw emissions pointers.
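// A hedged usage sketch from Python (assuming `decoder` is a constructed
// _LexiconDecoder and `emissions` is a contiguous float32 torch.Tensor of
// shape (T, N) holding frame-wise token scores):
//
//   T, N = emissions.shape
//   results = decoder.decode(emissions.data_ptr(), T, N)
//   best = results[0]
//   print(best.score, best.words, best.tokens)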
py::class_<LexiconDecoder>(m, "_LexiconDecoder")
.def(py::init<
LexiconDecoderOptions,
const TriePtr,
const LMPtr,
const int,
const int,
const int,
const std::vector<float>&,
const bool>())
.def("decode_begin", &LexiconDecoder::decodeBegin)
.def(
"decode_step",
&LexiconDecoder_decodeStep,
"emissions"_a,
"T"_a,
"N"_a)
.def("decode_end", &LexiconDecoder::decodeEnd)
.def("decode", &LexiconDecoder_decode, "emissions"_a, "T"_a, "N"_a)
.def("prune", &LexiconDecoder::prune, "look_back"_a = 0)
.def(
"get_best_hypothesis",
&LexiconDecoder::getBestHypothesis,
"look_back"_a = 0)
.def("get_all_final_hypothesis", &LexiconDecoder::getAllFinalHypothesis);
py::class_<Dictionary>(m, "_Dictionary")
.def(py::init<>())
.def(py::init<const std::vector<std::string>&>(), "tkns"_a)
.def(py::init<const std::string&>(), "filename"_a)
.def("entry_size", &Dictionary::entrySize)
.def("index_size", &Dictionary::indexSize)
.def("add_entry", &Dictionary_addEntry_0, "entry"_a, "idx"_a)
.def("add_entry", &Dictionary_addEntry_1, "entry"_a)
.def("get_entry", &Dictionary::getEntry, "idx"_a)
.def("set_default_index", &Dictionary::setDefaultIndex, "idx"_a)
.def("get_index", &Dictionary::getIndex, "entry"_a)
.def("contains", &Dictionary::contains, "entry"_a)
.def("is_contiguous", &Dictionary::isContiguous)
.def(
"map_entries_to_indices",
&Dictionary::mapEntriesToIndices,
"entries"_a)
.def(
"map_indices_to_entries",
&Dictionary::mapIndicesToEntries,
"indices"_a);
m.def("_create_word_dict", &createWordDict, "lexicon"_a);
m.def("_load_words", &loadWords, "filename"_a, "max_words"_a = -1);
#endif
}
} // namespace