recipes/local_prior_match/Decode_length_lpm.cpp (66 lines of code) (raw):

/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the MIT-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <flashlight/flashlight.h>
#include <gflags/gflags.h>
#include <glog/logging.h>

#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

#include "common/Defines.h"
#include "common/FlashlightUtils.h"
#include "common/Transforms.h"
#include "criterion/criterion.h"
#include "module/module.h"
#include "runtime/runtime.h"

using namespace w2l;

/**
 * Runs a trained model over every sample of a dataset and records the length
 * of each sample's Viterbi-path label sequence.
 *
 * Usage: <exec> [model] [dataset] [outputfile] [--flag overrides...]
 *
 * For every sample whose remapped Viterbi path is non-empty, one line
 * "<sampleId> reflen<N>" is written to [outputfile], where N is the number of
 * labels remaining after remapLabels(). Samples with an empty path are skipped.
 * Any failure to open or write the output file is fatal.
 */
int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  std::string exec(argv[0]);
  gflags::SetUsageMessage(
      "Usage: \n " + exec + " [model] [dataset] [outputfile]");

  // The three positional arguments are read directly from argv before gflags
  // parses the command line, so they must precede any --flag overrides.
  if (argc <= 3) {
    LOG(FATAL) << gflags::ProgramUsage();
  }
  const std::string reloadpath = argv[1];
  const std::string dataset = argv[2];
  const std::string outputfile = argv[3];

  // Load the serialized config, network and criterion from the model file.
  std::unordered_map<std::string, std::string> cfg;
  std::shared_ptr<fl::Module> network;
  std::shared_ptr<SequenceCriterion> criterion;
  W2lSerializer::load(reloadpath, cfg, network, criterion);

  // Restore the flags stored with the model, then let flags given on the
  // command line override them.
  auto flags = cfg.find(kGflags);
  if (flags == cfg.end()) {
    LOG(FATAL) << "Invalid config loaded from " << reloadpath;
  }
  LOG(INFO) << "Reading flags from config file " << reloadpath;
  gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);
  LOG(INFO) << "Parsing command line flags";
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  LOG(INFO) << "Gflags after parsing \n" << serializeGflags("; ");

  Dictionary dict(pathsConcat(FLAGS_tokensdir, FLAGS_tokens));
  if (FLAGS_eostoken) {
    dict.addEntry(kEosToken);
  }
  LOG(INFO) << "Number of classes (network) = " << dict.indexSize();
  LOG(INFO) << "[network] " << network->prettyString();

  af::setSeed(FLAGS_seed);

  DictionaryMap dicts;
  dicts.insert({kTargetIdx, dict});
  auto lexicon = loadWords(FLAGS_lexicon, FLAGS_maxword);
  // NOTE(review): the trailing (1, 0, 1) presumably mean batch size 1,
  // worker rank 0, world size 1 — confirm against createDataset's signature.
  auto testset = createDataset(dataset, dicts, lexicon, 1, 0, 1);

  network->eval();
  criterion->eval();

  // Check the output path up front instead of silently producing nothing
  // after decoding the whole dataset.
  std::ofstream out(outputfile);
  if (!out.is_open()) {
    LOG(FATAL) << "Failed to open output file " << outputfile;
  }

  for (auto& sample : *testset) {
    auto sampleId = readSampleIds(sample[kSampleIdx]).front();
    auto output = network->forward({fl::input(sample[kInputIdx])}).front();
    auto viterbipathArr = criterion->viterbiPath(output.array());
    auto viterbipath = afToVector<int>(viterbipathArr);
    remapLabels(viterbipath, dict);
    if (viterbipath.empty()) {
      continue;
    }
    // The decoded length is emitted as a pseudo-word "reflen<N>"; this assumes
    // no such token (e.g. "reflen1") is a valid word in the lexicon.
    // '\n' instead of std::endl avoids flushing the stream once per sample.
    out << sampleId << " reflen" << viterbipath.size() << '\n';
  }

  // close() flushes buffered output; report any write/flush failure rather
  // than exiting successfully with a truncated file.
  out.close();
  if (out.fail()) {
    LOG(FATAL) << "Error while writing output file " << outputfile;
  }
  return 0;
}