in recipes/streaming_convnets/tools/StreamingTDSModelConverter.cpp [140:376]
/**
 * Entry point of the streaming TDS model converter.
 *
 * Steps performed, in order:
 *  1. Load the acoustic model (network + criterion + saved config) from
 *     FLAGS_am and restore the gflags stored in that config; flags given on
 *     the command line / in FLAGS_flagsfile take precedence.
 *  2. Build the token dictionary (tokens file + replabel entries + the blank
 *     token when the criterion is CTC).
 *  3. Walk the architecture file line by line and convert every supported
 *     layer (C2, PD, R, LN, L, TDS) into its streaming counterpart, consuming
 *     the network parameters in order. V, RO, DO and SAUG lines are skipped.
 *  4. Write `acoustic_model.bin`, `tokens.txt`, `feature_extractor.bin` and,
 *     for ASG, `transitions.bin` into FLAGS_outdir.
 *  5. Run a random input through both the original and the streaming network
 *     and log every output element that differs by more than 1e-2.
 *
 * @return 0 on success. Invalid configuration is reported through
 *   LOG(FATAL); unusable file paths, unwritable outputs and unparsable
 *   architecture lines are reported through std::runtime_error /
 *   std::logic_error.
 */
int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  google::InstallFailureSignalHandler();

  /* ===================== Parse Options ===================== */
  LOG(INFO) << "Parsing command line flags";
  gflags::ParseCommandLineFlags(&argc, &argv, false);

  /* ===================== Create Network ===================== */
  std::shared_ptr<fl::Module> network;
  std::shared_ptr<SequenceCriterion> criterion;
  std::unordered_map<std::string, std::string> cfg;
  std::string version;
  LOG(INFO) << "[Network] Reading acoustic model from " << FLAGS_am;
  fl::ext::Serializer::load(FLAGS_am, version, cfg, network, criterion);
  network->eval();
  criterion->eval();

  LOG(INFO) << "[Network] " << network->prettyString();
  LOG(INFO) << "[Criterion] " << criterion->prettyString();
  LOG(INFO) << "[Network] Number of params: " << numTotalParams(network);

  auto flags = cfg.find(kGflags);
  if (flags == cfg.end()) {
    LOG(FATAL) << "[Network] Invalid config loaded from " << FLAGS_am;
  }
  LOG(INFO) << "[Network] Updating flags from config file: " << FLAGS_am;
  gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);

  // override with user-specified flags
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  if (!FLAGS_flagsfile.empty()) {
    gflags::ReadFromFlagsFile(FLAGS_flagsfile, argv[0], true);
  }

  LOG(INFO) << "Gflags after parsing \n" << serializeGflags("; ");

  /* ===================== Create Dictionary ===================== */
  auto dictPath = fl::lib::pathsConcat(FLAGS_tokensdir, FLAGS_tokens);
  if (dictPath.empty() || !fileExists(dictPath)) {
    throw std::runtime_error(
        "Invalid dictionary filepath specified " + dictPath);
  }
  text::Dictionary tokenDict(dictPath);
  for (int64_t r = 1; r <= FLAGS_replabel; ++r) {
    tokenDict.addEntry("<" + std::to_string(r) + ">");
  }
  if (FLAGS_criterion == kCtcCriterion) {
    tokenDict.addEntry(kBlankToken);
  } else if (FLAGS_criterion != kAsgCriterion) {
    LOG(FATAL) << "This script currently support only CTC/ASG criterion";
  }
  int numTokens = tokenDict.indexSize();
  LOG(INFO) << "Number of classes (network): " << numTokens;

  int nFeat = 0;
  if (FLAGS_mfsc) {
    nFeat = FLAGS_filterbanks;
  } else {
    LOG(FATAL) << "This script currently support only mfsc features";
  }

  /* ===================== Convert Layers ===================== */
  auto lines = getFileContent(fl::lib::pathsConcat(FLAGS_archdir, FLAGS_arch));
  auto streamingModule = std::make_shared<streaming::Sequential>();
  auto params = network->params();
  int curFeatSz = nFeat;
  int paramIdx = 0;
  // Padding announced by a "PD" line, consumed by the following "C2" line.
  // -1 means "no explicit padding given".
  int leftPad = -1, rightPad = -1;

  // Aborts when fewer than `count` parameters remain at `paramIdx`, so that
  // the indexing below never reads past the end of `params`.
  auto requireParams = [&params, &paramIdx](
                           int count, const std::string& moduleName) {
    if (params.size() < static_cast<size_t>(paramIdx) + count) {
      LOG(FATAL) << "Error serializing " << moduleName
                 << " module. Not enough parameters.";
    }
  };

  for (size_t i = 0; i < lines.size(); ++i) {
    auto columns = splitOnWhitespace(lines[i], true);
    if (columns.empty()) {
      continue;
    }
    auto layerType = columns[0];
    if (layerType == "C2") {
      if (columns.size() < 8) {
        LOG(FATAL) << "Invalid arch specified for C2";
      }
      requireParams(2, "Conv1d");
      auto conv1d = convertConv1d(
          std::stoi(columns[1]) * nFeat,
          std::stoi(columns[2]) * nFeat,
          std::stoi(columns[3]),
          std::stoi(columns[5]),
          {leftPad, rightPad},
          nFeat,
          params[paramIdx],
          params[paramIdx + 1]);
      streamingModule->add(conv1d);
      // The pending padding has been consumed by this convolution.
      leftPad = -1;
      rightPad = -1;
      paramIdx += 2;
      curFeatSz = std::stoi(columns[2]) * nFeat;
    } else if (layerType == "PD") {
      if (columns.size() != 4) {
        LOG(FATAL) << "Padding is supported only along time axis";
      }
      // A trailing PD line has no successor; do not index past the end.
      if (i + 1 >= lines.size() || !startsWith(lines[i + 1], "C2")) {
        LOG(FATAL) << "Padding layer must be followed by conv layer";
      }
      leftPad = std::stoi(columns[2]);
      rightPad = std::stoi(columns[3]);
    } else if (layerType == "R") {
      streamingModule->add(
          std::make_shared<streaming::Relu>(streaming::DataType::FLOAT));
    } else if (layerType == "LN") {
      if (columns.size() < 3) {
        LOG(FATAL) << "Invalid arch specified for LN";
      }
      if (columns[1] != "1" || columns[2] != "2") {
        LOG(FATAL)
            << "Unsupported LayerNorm axis: must be {1, 2} for streaming";
      }
      requireParams(2, "LayerNorm");
      auto lyrNorm =
          convertLayerNorm(curFeatSz, params[paramIdx], params[paramIdx + 1]);
      streamingModule->add(lyrNorm);
      paramIdx += 2;
    } else if (layerType == "L") {
      if (columns.size() < 3) {
        LOG(FATAL) << "Invalid arch specified for L";
      }
      requireParams(2, "Linear");
      int outDim = (columns[2] == "NLABEL") ? numTokens : std::stoi(columns[2]);
      int inDim = std::stoi(columns[1]);
      auto linear =
          convertLinear(inDim, outDim, params[paramIdx], params[paramIdx + 1]);
      streamingModule->add(linear);
      paramIdx += 2;
    } else if (layerType == "TDS") {
      if (columns.size() < 4) {
        LOG(FATAL) << "Invalid arch specified for TDS";
      }
      // A TDS block consumes 10 consecutive parameters.
      requireParams(10, "TDS");
      auto stds = convertTDS(
          std::stoi(columns[1]),
          std::stoi(columns[2]),
          std::stoi(columns[3]),
          (columns.size() > 6) ? std::stoi(columns[6]) : -1,
          {params.begin() + paramIdx, params.begin() + paramIdx + 10},
          (columns.size() > 5) ? std::stoi(columns[5]) : 0);
      streamingModule->add(stds);
      paramIdx += 10;
    } else if (layerType == "V") {
      std::cerr << "Skipping View module: " << lines[i] << std::endl;
    } else if (layerType == "RO") {
      std::cerr << "Skipping Reorder module: " << lines[i] << std::endl;
    } else if (layerType == "DO") {
      std::cerr << "Skipping Dropout module: " << lines[i] << std::endl;
    } else if (layerType == "SAUG") {
      std::cerr << "Skipping SpecAugment module: " << lines[i] << std::endl;
    } else if (layerType != "#") {
      throw std::logic_error("Unrecognized/unparsable line " + lines[i]);
    }
  }

  /* ===================== Serialize Outputs ===================== */
  {
    std::string amFilePath =
        fl::lib::pathsConcat(FLAGS_outdir, "acoustic_model.bin");
    std::ofstream amFile(amFilePath, std::ios::binary);
    LOG(INFO) << "Serializing acoustic model to '" << amFilePath << "'";
    if (!amFile.is_open()) {
      throw std::runtime_error(
          "failed to open file for writing: " + amFilePath);
    }
    cereal::BinaryOutputArchive ar(amFile);
    ar(streamingModule);
  }
  {
    std::string tokenFilePath =
        fl::lib::pathsConcat(FLAGS_outdir, "tokens.txt");
    std::ofstream tokenFile(tokenFilePath);
    LOG(INFO) << "Writing tokens file to '" << tokenFilePath << "'";
    if (!tokenFile.is_open()) {
      throw std::runtime_error(
          "failed to open file for writing: " + tokenFilePath);
    }
    for (int i = 0; i < tokenDict.indexSize(); ++i) {
      tokenFile << tokenDict.getEntry(i) << "\n";
    }
    tokenFile.close();
  }
  if (FLAGS_criterion == kAsgCriterion) {
    // The first criterion parameter is expected to hold one value per
    // (token, token) pair.
    if (criterion->params().size() == 0 ||
        criterion->param(0).elements() !=
            tokenDict.indexSize() * tokenDict.indexSize()) {
      throw std::runtime_error("Invalid criterion parameters for ASG");
    }
    std::string transitionsFilePath =
        fl::lib::pathsConcat(FLAGS_outdir, "transitions.bin");
    // Opened in binary mode like the other .bin outputs.
    std::ofstream transitionsFile(transitionsFilePath, std::ios::binary);
    LOG(INFO) << "Writing transitions file to '" << transitionsFilePath << "'";
    if (!transitionsFile.is_open()) {
      throw std::runtime_error(
          "failed to open file for writing: " + transitionsFilePath);
    }
    std::vector<float> transitionsVec(criterion->param(0).elements());
    criterion->param(0).host(transitionsVec.data());
    fl::save(transitionsFile, transitionsVec);
    transitionsFile.close();
  }
  {
    std::string featFilePath =
        fl::lib::pathsConcat(FLAGS_outdir, "feature_extractor.bin");
    std::ofstream featFile(featFilePath, std::ios::binary);
    LOG(INFO) << "Serializing feature extraction model to '" << featFilePath
              << "'";
    auto featureModule = std::make_shared<streaming::Sequential>();
    featureModule->add(std::make_shared<streaming::LogMelFeature>(nFeat));
    LOG_IF(FATAL, FLAGS_localnrmlleftctx <= 0)
        << "Local Norm should be used for online inference";
    featureModule->add(std::make_shared<streaming::LocalNorm>(
        nFeat, FLAGS_localnrmlleftctx, FLAGS_localnrmlrightctx));
    if (!featFile.is_open()) {
      throw std::runtime_error(
          "failed to open file for writing: " + featFilePath);
    }
    cereal::BinaryOutputArchive ar(featFile);
    ar(featureModule);
  }

  /* ===================== Verify Conversion ===================== */
  LOG(INFO) << "verifying serialization ...";
  // 16 frames of random features. The feature dimension follows nFeat, the
  // same value the streaming modules above were built with.
  af::array inputArr = af::randu(16, nFeat);
  auto outputArr = network->forward({fl::Variable(inputArr, false)})[0].array();
  std::vector<float> outputVec(outputArr.elements());
  outputArr.host(outputVec.data());

  // Reorder the input before copying it to host memory for the streaming
  // module.
  std::vector<float> inputVec(inputArr.elements());
  af::reorder(inputArr, 2, 1, 0).host(inputVec.data());
  auto inputState = std::make_shared<streaming::ModuleProcessingState>(1);
  std::shared_ptr<streaming::IOBuffer> inputBuffer = inputState->buffer(0);
  inputBuffer->write<float>(inputVec.data(), inputVec.size());

  streamingModule->start(inputState);
  auto outputState = streamingModule->run(inputState);
  streamingModule->finish(inputState);

  std::shared_ptr<streaming::IOBuffer> outputBuffer = outputState->buffer(0);
  LOG_IF(FATAL, outputBuffer->size<float>() != outputVec.size())
      << "[Serialization Error] Incorrect output sizes";
  float* outPtr = outputBuffer->data<float>();
  for (int i = 0; i < outputBuffer->size<float>(); i++) {
    float streamingOut = outPtr[i];
    float w2lOut = outputVec[i];
    LOG_IF(ERROR, fabs(streamingOut - w2lOut) > 1e-2)
        << "[Serialization Error] Mismatched output w2l:" << w2lOut
        << " vs streaming:" << streamingOut;
  }
  LOG(INFO) << "Done !";
  return 0;
}