int main()

in recipes/joint_training_vox_populi/cpc/Decode.cpp [60:854]


int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  google::InstallFailureSignalHandler();
  std::string exec(argv[0]);
  std::vector<std::string> argvs;
  for (int i = 0; i < argc; i++) {
    argvs.emplace_back(argv[i]);
  }
  gflags::SetUsageMessage("Usage: Please refer to https://git.io/JvJuR");
  if (argc <= 1) {
    LOG(FATAL) << gflags::ProgramUsage();
  }
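  // Example invocation (binary name, paths, and values are hypothetical):
  //   decode --am=/path/to/am.bin --tokens=/path/to/tokens.txt
  //     --lexicon=/path/to/lexicon.txt --lm=/path/to/lm.bin --lmtype=kenlm
  //     --decodertype=wrd --test=test.lst --datadir=/path/to/data
  //     --sclite=/path/to/output --nthread_decoder=4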

  /* ===================== Parse Options ===================== */
  LOG(INFO) << "Parsing command line flags";
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  auto flagsfile = FLAGS_flagsfile;
  if (!flagsfile.empty()) {
    LOG(INFO) << "Reading flags from file " << flagsfile;
    gflags::ReadFromFlagsFile(flagsfile, argv[0], true);
    // Re-parse command line flags to override values in the flag file.
    gflags::ParseCommandLineFlags(&argc, &argv, false);
  }

  if (!FLAGS_fl_log_level.empty()) {
    fl::Logging::setMaxLoggingLevel(fl::logLevelValue(FLAGS_fl_log_level));
  }
  fl::VerboseLogging::setMaxLoggingLevel(FLAGS_fl_vlog_level);

  /* ===================== Create Network ===================== */
  if (FLAGS_emission_dir.empty() && FLAGS_am.empty()) {
    LOG(FATAL) << "Both flags are empty: `-emission_dir` and `-am`";
  }

  std::shared_ptr<fl::Sequential> network;
  std::shared_ptr<SequenceCriterion> _criterion;
  std::shared_ptr<SequenceCriterion> criterion;
  std::unordered_map<std::string, std::string> cfg;
  std::string version;
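  // The serialized AM archive is expected to hold, in order: a version string,
  // the training config, the network, and two criterion objects. Only the last
  // criterion is used for decoding; `_criterion` is just a placeholder slot.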

  /* Using acoustic model */
  if (!FLAGS_am.empty()) {
    LOG(INFO) << "[Network] Reading acoustic model from " << FLAGS_am;
    af::setDevice(0);
    Serializer::load(FLAGS_am, version, cfg, network, _criterion, criterion);
    network->eval();
    if (version != FL_APP_ASR_VERSION) {
      LOG(WARNING) << "[Network] Model version " << version
                   << " and code version " << FL_APP_ASR_VERSION;
    }
    LOG(INFO) << "[Network] " << network->prettyString();
    if (criterion) {
      criterion->eval();
      LOG(INFO) << "[Criterion] " << criterion->prettyString();
    }
    LOG(INFO) << "[Network] Number of params: " << numTotalParams(network);

    auto flags = cfg.find(kGflags);
    if (flags == cfg.end()) {
      LOG(FATAL) << "[Network] Invalid config loaded from " << FLAGS_am;
    }
    LOG(INFO) << "[Network] Updating flags from config file: " << FLAGS_am;
    gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);
  }

  // override with user-specified flags
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  if (!flagsfile.empty()) {
    gflags::ReadFromFlagsFile(flagsfile, argv[0], true);
    // Re-parse command line flags to override values in the flag file.
    gflags::ParseCommandLineFlags(&argc, &argv, false);
  }

  // Copy values from deprecated flags to their replacements only when the
  // deprecated flags are set and the corresponding new flags aren't.
  handleDeprecatedFlags();

  LOG(INFO) << "Gflags after parsing \n"
            << fl::pkg::speech::serializeGflags("; ");

  /* ===================== Create Dictionary ===================== */
  auto dictPath = FLAGS_tokens;
  if (dictPath.empty() || !fl::lib::fileExists(dictPath)) {
    throw std::runtime_error("Invalid dictionary filepath specified.");
  }
  fl::lib::text::Dictionary tokenDict(dictPath);
  // Setup-specific modifications
  for (int64_t r = 1; r <= FLAGS_replabel; ++r) {
    tokenDict.addEntry("<" + std::to_string(r) + ">");
  }
  // ctc expects the blank label last
  if (FLAGS_criterion2 == kCtcCriterion) {
    tokenDict.addEntry(kBlankToken);
  }
  bool isSeq2seqCrit = FLAGS_criterion == kSeq2SeqTransformerCriterion ||
      FLAGS_criterion == kSeq2SeqRNNCriterion;
  if (isSeq2seqCrit) {
    tokenDict.addEntry(fl::pkg::speech::kEosToken);
    tokenDict.addEntry(fl::lib::text::kPadToken);
  }

  int numClasses = tokenDict.indexSize();
  LOG(INFO) << "Number of classes (network): " << numClasses;

  fl::lib::text::Dictionary wordDict;
  fl::lib::text::LexiconMap lexicon;
  if (!FLAGS_lexicon.empty()) {
    lexicon = fl::lib::text::loadWords(FLAGS_lexicon, FLAGS_maxword);
    wordDict = fl::lib::text::createWordDict(lexicon);
    LOG(INFO) << "Number of words: " << wordDict.indexSize();
  } else {
    if (FLAGS_uselexicon || FLAGS_decodertype == "wrd") {
      LOG(FATAL) << "For lexicon-based beam-search decoder "
                 << "lexicon shouldn't be empty";
    }
  }

  /* =============== Prepare Sharable Decoder Components ============== */
  // Prepare counters
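  // These per-thread "slice" counters are indexed by decoder-thread id and
  // merged into the global WER/TER statistics after all threads have joined.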
  std::vector<double> sliceWrdDst(FLAGS_nthread_decoder);
  std::vector<double> sliceTknDst(FLAGS_nthread_decoder);
  std::vector<int> sliceNumWords(FLAGS_nthread_decoder, 0);
  std::vector<int> sliceNumTokens(FLAGS_nthread_decoder, 0);
  std::vector<int> sliceNumSamples(FLAGS_nthread_decoder, 0);
  std::vector<double> sliceTime(FLAGS_nthread_decoder, 0);

  // Prepare criterion
  CriterionType criterionType = CriterionType::ASG;
  if (FLAGS_criterion2 == kCtcCriterion) {
    criterionType = CriterionType::CTC;
  } else if (
      FLAGS_criterion == kSeq2SeqRNNCriterion ||
      FLAGS_criterion == kSeq2SeqTransformerCriterion) {
    criterionType = CriterionType::S2S;
  } else if (FLAGS_criterion2 != kAsgCriterion) {
    LOG(FATAL) << "[Decoder] Invalid model type: " << FLAGS_criterion2;
  }

  std::vector<float> transition;
  if (FLAGS_criterion2 == kAsgCriterion) {
    transition = afToVector<float>(criterion->param(0).array());
  }

  // Prepare log writer
  std::mutex hypMutex, refMutex, logMutex;
  std::ofstream hypStream, refStream, logStream;
  if (!FLAGS_sclite.empty()) {
    auto fileName = cleanFilepath(FLAGS_test);
    auto hypPath = pathsConcat(FLAGS_sclite, fileName + ".hyp");
    auto refPath = pathsConcat(FLAGS_sclite, fileName + ".ref");
    auto logPath = pathsConcat(FLAGS_sclite, fileName + ".log");
    hypStream.open(hypPath);
    refStream.open(refPath);
    logStream.open(logPath);
    if (!hypStream.is_open() || !hypStream.good()) {
      LOG(FATAL) << "Error opening hypothesis file: " << hypPath;
    }
    if (!refStream.is_open() || !refStream.good()) {
      LOG(FATAL) << "Error opening reference file: " << refPath;
    }
    if (!logStream.is_open() || !logStream.good()) {
      LOG(FATAL) << "Error opening log file: " << logPath;
    }
  }

  auto writeHyp = [&hypMutex, &hypStream](const std::string& hypStr) {
    std::lock_guard<std::mutex> lock(hypMutex);
    hypStream << hypStr;
  };
  auto writeRef = [&refMutex, &refStream](const std::string& refStr) {
    std::lock_guard<std::mutex> lock(refMutex);
    refStream << refStr;
  };
  auto writeLog = [&logMutex, &logStream](const std::string& logStr) {
    std::lock_guard<std::mutex> lock(logMutex);
    logStream << logStr;
  };
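  // The writers above are guarded by per-stream mutexes because multiple
  // decoder threads emit hypotheses, references, and logs concurrently.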

  // Build Language Model
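  // The LM scores against usrDict: the word dictionary (with an explicit
  // unknown-word index) when decoding with a word-level LM, otherwise the
  // token dictionary.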
  int unkWordIdx = -1;

  fl::lib::text::Dictionary usrDict = tokenDict;
  if (!FLAGS_lm.empty() && FLAGS_decodertype == "wrd") {
    usrDict = wordDict;
    unkWordIdx = wordDict.getIndex(kUnkToken);
  }

  std::shared_ptr<fl::lib::text::LM> lm =
      std::make_shared<fl::lib::text::ZeroLM>();
  if (!FLAGS_lm.empty()) {
    if (FLAGS_lmtype == "kenlm") {
      lm = std::make_shared<fl::lib::text::KenLM>(FLAGS_lm, usrDict);
      if (!lm) {
        LOG(FATAL) << "[LM constructing] Failed to load LM: " << FLAGS_lm;
      }
    } else if (FLAGS_lmtype == "convlm") {
      af::setDevice(0);
      LOG(INFO) << "[ConvLM]: Loading LM from " << FLAGS_lm;
      std::shared_ptr<fl::Module> convLmModel;
      std::string convlmVersion;
      Serializer::load(FLAGS_lm, convlmVersion, convLmModel);
      if (convlmVersion != FL_APP_ASR_VERSION) {
        LOG(WARNING) << "[Convlm] Model version " << convlmVersion
                     << " and code version " << FL_APP_ASR_VERSION;
      }
      convLmModel->eval();

      auto getConvLmScoreFunc = buildGetConvLmScoreFunction(convLmModel);
      lm = std::make_shared<fl::lib::text::ConvLM>(
          getConvLmScoreFunc,
          FLAGS_lm_vocab,
          usrDict,
          FLAGS_lm_memory,
          FLAGS_beamsize);
    } else {
      LOG(FATAL) << "[LM constructing] Invalid LM Type: " << FLAGS_lmtype;
    }
  }
  LOG(INFO) << "[Decoder] LM constructed.";

  // Build Trie
  int blankIdx =
      FLAGS_criterion2 == kCtcCriterion ? tokenDict.getIndex(kBlankToken) : -1;
  int silIdx = -1;
  if (FLAGS_wordseparator != "") {
    silIdx = tokenDict.getIndex(FLAGS_wordseparator);
  }
  std::shared_ptr<fl::lib::text::Trie> trie = buildTrie(
      FLAGS_decodertype,
      FLAGS_uselexicon,
      lm,
      FLAGS_smearing,
      tokenDict,
      lexicon,
      wordDict,
      silIdx,
      FLAGS_replabel);
  LOG(INFO) << "[Decoder] Trie smeared.\n";

  /* ===================== Create Dataset ===================== */
  fl::lib::audio::FeatureParams featParams(
      FLAGS_samplerate,
      FLAGS_framesizems,
      FLAGS_framestridems,
      FLAGS_filterbanks,
      FLAGS_lowfreqfilterbank,
      FLAGS_highfreqfilterbank,
      FLAGS_mfcccoeffs,
      kLifterParam /* lifterparam */,
      FLAGS_devwin /* delta window */,
      FLAGS_devwin /* delta-delta window */);
  featParams.useEnergy = false;
  featParams.usePower = false;
  featParams.zeroMeanFrame = false;
  FeatureType featType =
      getFeatureType(FLAGS_features_type, FLAGS_channels, featParams).second;

  TargetGenerationConfig targetGenConfig(
      FLAGS_wordseparator,
      FLAGS_sampletarget,
      FLAGS_criterion2,
      FLAGS_surround,
      isSeq2seqCrit,
      FLAGS_replabel,
      true /* skip unk */,
      FLAGS_usewordpiece /* fallback2LetterWordSepLeft */,
      !FLAGS_usewordpiece /* fallback2LetterWordSepRight */);

  auto inputTransform = inputFeatures(
      featParams,
      featType,
      {FLAGS_localnrmlleftctx, FLAGS_localnrmlrightctx},
      /*sfxConf=*/{});
  auto targetTransform = targetFeatures(tokenDict, lexicon, targetGenConfig);
  auto wordTransform = wordFeatures(wordDict);
  int targetpadVal = isSeq2seqCrit
      ? tokenDict.getIndex(fl::lib::text::kPadToken)
      : kTargetPadValue;
  int wordpadVal = wordDict.getIndex(kUnkToken);

  std::vector<std::string> testSplits = fl::lib::split(",", FLAGS_test, true);
  auto ds = createDataset(
      testSplits,
      FLAGS_datadir,
      1 /* batchsize */,
      inputTransform,
      targetTransform,
      wordTransform,
      std::make_tuple(0, targetpadVal, wordpadVal),
      0 /* worldrank */,
      1 /* worldsize */);

  int nSamples = ds->size();
  if (FLAGS_maxload > 0) {
    nSamples = std::min(nSamples, FLAGS_maxload);
  }
  LOG(INFO) << "[Dataset] Dataset loaded, with " << nSamples << " samples.";

  /* ===================== AM Forwarding ===================== */
  using EmissionQueue = fl::lib::ProducerConsumerQueue<EmissionTargetPair>;
  EmissionQueue emissionQueue(FLAGS_emission_queue_size);
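  // Producer-consumer pipeline: the AM-forwarding threads below push
  // (emission, target) pairs into this queue (bounded by
  // FLAGS_emission_queue_size) while decoder threads pop from it.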

  auto runAmForward = [&network,
                       &criterion,
                       &nSamples,
                       &ds,
                       &tokenDict,
                       &wordDict,
                       &emissionQueue,
                       &isSeq2seqCrit](int tid) {
    // Initialize AM
    af::setDevice(tid);
    std::shared_ptr<fl::Sequential> localNetwork = network;
    std::shared_ptr<SequenceCriterion> localCriterion = criterion;
    std::shared_ptr<SequenceCriterion> _localCriterion;
    if (tid != 0) {
      std::unordered_map<std::string, std::string> dummyCfg;
      std::string dummyVersion;
      Serializer::load(
          FLAGS_am,
          dummyVersion,
          dummyCfg,
          localNetwork,
          _localCriterion,
          localCriterion);
      localNetwork->eval();
      localCriterion->eval();
    }

    std::vector<int64_t> selectedIds;
    for (int64_t i = tid; i < nSamples; i += FLAGS_nthread_decoder_am_forward) {
      selectedIds.emplace_back(i);
    }
    std::shared_ptr<fl::Dataset> localDs =
        std::make_shared<fl::ResampleDataset>(ds, selectedIds);
    localDs = std::make_shared<fl::PrefetchDataset>(
        localDs, FLAGS_nthread, FLAGS_nthread);

    for (auto& sample : *localDs) {
      auto sampleId = readSampleIds(sample[kSampleIdx]).front();

      /* 2. Load Targets */
      TargetUnit targetUnit;
      auto tokenTarget = afToVector<int>(sample[kTargetIdx]);
      auto wordTarget = afToVector<int>(sample[kWordIdx]);
      // TODO: we will reform the dataset so that the loaded word
      // targets are strings already
      std::vector<std::string> wordTargetStr;
      if (FLAGS_uselexicon) {
        wordTargetStr = wrdIdx2Wrd(wordTarget, wordDict);
      } else {
        auto letterTarget = tknTarget2Ltr(
            tokenTarget,
            tokenDict,
            FLAGS_criterion2,
            FLAGS_surround,
            isSeq2seqCrit,
            FLAGS_replabel,
            FLAGS_usewordpiece,
            FLAGS_wordseparator);
        wordTargetStr = tkn2Wrd(letterTarget, FLAGS_wordseparator);
      }

      targetUnit.wordTargetStr = wordTargetStr;
      targetUnit.tokenTarget = tokenTarget;

      /* 3. Load Emissions */
      EmissionUnit emissionUnit;
      if (FLAGS_emission_dir.empty()) {
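        // The network is forwarded module-by-module with hard-coded indices:
        // modules 0-2 first (assumed to be the CPC encoder/context stack),
        // then module 3 with a padding mask derived from the sample durations,
        // and finally module 4, which produces the emission scores.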
        int idx = 0;
        auto enc_out = localNetwork->module(idx++)
                           ->forward({fl::input(sample[kInputIdx])})
                           .front();
        enc_out = localNetwork->module(idx++)->forward({enc_out}).front();
        enc_out = localNetwork->module(idx++)->forward({enc_out}).front();
        enc_out = w2l::cpc::forwardSequentialModuleWithPadMask(
            enc_out, localNetwork->module(idx++), sample[kDurationIdx]);
        auto rawEmission =
            localNetwork->module(idx)->forward({enc_out}).front();
        emissionUnit = EmissionUnit(
            afToVector<float>(rawEmission),
            sampleId,
            rawEmission.dims(1),
            rawEmission.dims(0));
      } else {
        auto cleanTestPath = cleanFilepath(FLAGS_test);
        std::string emissionDir =
            pathsConcat(FLAGS_emission_dir, cleanTestPath);
        std::string savePath = pathsConcat(emissionDir, sampleId + ".bin");
        std::string eVersion;
        Serializer::load(savePath, eVersion, emissionUnit);
      }

      emissionQueue.add({emissionUnit, targetUnit});
    }

    // The AM is only needed for the forward pass; release the local handle so
    // its GPU/host memory can be reclaimed once no other references remain.
    localNetwork.reset();

    af::deviceGC(); // Explicitly trigger ArrayFire's device garbage collection.
  };

  /* ===================== Decode ===================== */
  auto runDecoder = [&criterion,
                     &lm,
                     &trie,
                     &silIdx,
                     &blankIdx,
                     &unkWordIdx,
                     &criterionType,
                     &transition,
                     &usrDict,
                     &tokenDict,
                     &wordDict,
                     &emissionQueue,
                     &writeHyp,
                     &writeRef,
                     &writeLog,
                     &sliceWrdDst,
                     &sliceTknDst,
                     &sliceNumWords,
                     &sliceNumTokens,
                     &sliceNumSamples,
                     &sliceTime,
                     &isSeq2seqCrit](int tid) {
    /* 1. Prepare GPU-dependent resources */
    // Note: these two GPU-dependent models should be placed on different cards
    // for different threads, and nthread_decoder should not exceed the number
    // of available GPUs.
    std::shared_ptr<SequenceCriterion> localCriterion = criterion;
    std::shared_ptr<fl::lib::text::LM> localLm = lm;
    if (FLAGS_lmtype == "convlm" || criterionType == CriterionType::S2S) {
      if (tid >= af::getDeviceCount()) {
        LOG(FATAL)
            << "FLAGS_nthread_decoder exceeds the number of visible GPUs";
      }
      af::setDevice(tid);
    }

    // Make a copy for non-main threads.
    if (tid != 0) {
      if (FLAGS_lmtype == "convlm") {
        LOG(INFO) << "[ConvLM]: Loading LM from " << FLAGS_lm;
        std::shared_ptr<fl::Module> convLmModel;
        std::string convlmVersion;
        Serializer::load(FLAGS_lm, convlmVersion, convLmModel);
        convLmModel->eval();

        auto getConvLmScoreFunc = buildGetConvLmScoreFunction(convLmModel);
        localLm = std::make_shared<fl::lib::text::ConvLM>(
            getConvLmScoreFunc,
            FLAGS_lm_vocab,
            usrDict,
            FLAGS_lm_memory,
            FLAGS_beamsize);
      }

      if (criterionType == CriterionType::S2S) {
        std::shared_ptr<fl::Module> dummyNetwork;
        std::unordered_map<std::string, std::string> dummyCfg;
        Serializer::load(FLAGS_am, dummyCfg, dummyNetwork, localCriterion);
        localCriterion->eval();
      }
    }

    /* 2. Build Decoder */
    std::unique_ptr<fl::lib::text::Decoder> decoder;
    if (FLAGS_decodertype != "wrd" && FLAGS_decodertype != "tkn") {
      LOG(FATAL) << "Unsupported decoder type: " << FLAGS_decodertype;
    }
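    // Four decoder variants are built below, chosen by criterion type
    // (seq2seq vs. CTC/ASG) and by whether a lexicon constrains the search.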

    if (criterionType == CriterionType::S2S) {
      auto amUpdateFunc = FLAGS_criterion == kSeq2SeqRNNCriterion
          ? buildSeq2SeqRnnAmUpdateFunction(
                localCriterion,
                FLAGS_decoderattnround,
                FLAGS_beamsize,
                FLAGS_attentionthreshold,
                FLAGS_smoothingtemperature)
          : buildSeq2SeqTransformerAmUpdateFunction(
                localCriterion,
                FLAGS_beamsize,
                FLAGS_attentionthreshold,
                FLAGS_smoothingtemperature);
      int eosIdx = tokenDict.getIndex(fl::pkg::speech::kEosToken);

      if (FLAGS_decodertype == "wrd" || FLAGS_uselexicon) {
        decoder.reset(new fl::lib::text::LexiconSeq2SeqDecoder(
            {
                .beamSize = FLAGS_beamsize,
                .beamSizeToken = FLAGS_beamsizetoken,
                .beamThreshold = FLAGS_beamthreshold,
                .lmWeight = FLAGS_lmweight,
                .wordScore = FLAGS_wordscore,
                .eosScore = FLAGS_eosscore,
                .logAdd = FLAGS_logadd,
            },
            trie,
            localLm,
            eosIdx,
            amUpdateFunc,
            FLAGS_maxdecoderoutputlen,
            FLAGS_decodertype == "tkn"));
        LOG(INFO) << "[Decoder] LexiconSeq2Seq decoder with "
                  << FLAGS_decodertype << "-LM loaded in thread: " << tid;
      } else {
        decoder.reset(new fl::lib::text::LexiconFreeSeq2SeqDecoder(
            {
                .beamSize = FLAGS_beamsize,
                .beamSizeToken = FLAGS_beamsizetoken,
                .beamThreshold = FLAGS_beamthreshold,
                .lmWeight = FLAGS_lmweight,
                .eosScore = FLAGS_eosscore,
                .logAdd = FLAGS_logadd,
            },
            localLm,
            eosIdx,
            amUpdateFunc,
            FLAGS_maxdecoderoutputlen));
        LOG(INFO)
            << "[Decoder] LexiconFreeSeq2Seq decoder with token-LM loaded in thread: "
            << tid;
      }
    } else {
      if (FLAGS_decodertype == "wrd" || FLAGS_uselexicon) {
        decoder.reset(new fl::lib::text::LexiconDecoder(
            {.beamSize = FLAGS_beamsize,
             .beamSizeToken = FLAGS_beamsizetoken,
             .beamThreshold = FLAGS_beamthreshold,
             .lmWeight = FLAGS_lmweight,
             .wordScore = FLAGS_wordscore,
             .unkScore = FLAGS_unkscore,
             .silScore = FLAGS_silscore,
             .logAdd = FLAGS_logadd,
             .criterionType = criterionType},
            trie,
            localLm,
            silIdx,
            blankIdx,
            unkWordIdx,
            transition,
            FLAGS_decodertype == "tkn"));
        LOG(INFO) << "[Decoder] Lexicon decoder with " << FLAGS_decodertype
                  << "-LM loaded in thread: " << tid;
      } else {
        decoder.reset(new fl::lib::text::LexiconFreeDecoder(
            {.beamSize = FLAGS_beamsize,
             .beamSizeToken = FLAGS_beamsizetoken,
             .beamThreshold = FLAGS_beamthreshold,
             .lmWeight = FLAGS_lmweight,
             .silScore = FLAGS_silscore,
             .logAdd = FLAGS_logadd,
             .criterionType = criterionType},
            localLm,
            silIdx,
            blankIdx,
            transition));
        LOG(INFO)
            << "[Decoder] Lexicon-free decoder with token-LM loaded in thread: "
            << tid;
      }
    }
    /* 3. Get data and run decoder */
    TestMeters meters;
    EmissionTargetPair emissionTargetPair;
    while (emissionQueue.get(emissionTargetPair)) {
      const auto& emissionUnit = emissionTargetPair.first;
      const auto& targetUnit = emissionTargetPair.second;

      const auto& nFrames = emissionUnit.nFrames;
      const auto& nTokens = emissionUnit.nTokens;
      const auto& emission = emissionUnit.emission;
      const auto& sampleId = emissionUnit.sampleId;
      const auto& wordTarget = targetUnit.wordTargetStr;
      const auto& tokenTarget = targetUnit.tokenTarget;
      // DecodeResult
      meters.timer.reset();
      meters.timer.resume();
      const auto& results = decoder->decode(emission.data(), nFrames, nTokens);
      meters.timer.stop();

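      // In beam-dump mode every hypothesis in the beam is post-processed and
      // written out with its scores; otherwise only the top hypothesis is
      // scored against the reference.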
      int nTopHyps = FLAGS_isbeamdump ? results.size() : 1;
      for (int i = 0; i < nTopHyps; i++) {
        // Cleanup predictions
        auto rawWordPrediction = results[i].words;
        auto rawTokenPrediction = results[i].tokens;

        auto letterTarget = tknTarget2Ltr(
            tokenTarget,
            tokenDict,
            FLAGS_criterion2,
            FLAGS_surround,
            isSeq2seqCrit,
            FLAGS_replabel,
            FLAGS_usewordpiece,
            FLAGS_wordseparator);
        auto letterPrediction = tknPrediction2Ltr(
            rawTokenPrediction,
            tokenDict,
            FLAGS_criterion2,
            FLAGS_surround,
            isSeq2seqCrit,
            FLAGS_replabel,
            FLAGS_usewordpiece,
            FLAGS_wordseparator);
        std::vector<std::string> wordPrediction;
        if (FLAGS_uselexicon) {
          rawWordPrediction =
              validateIdx(rawWordPrediction, wordDict.getIndex(kUnkToken));
          wordPrediction = wrdIdx2Wrd(rawWordPrediction, wordDict);
        } else {
          wordPrediction = tkn2Wrd(letterPrediction, FLAGS_wordseparator);
        }
        auto wordTargetStr = join(" ", wordTarget);
        auto wordPredictionStr = join(" ", wordPrediction);
        // Normal decoding and computing WER
        if (!FLAGS_isbeamdump) {
          meters.wrdDstSlice.add(wordPrediction, wordTarget);
          meters.tknDstSlice.add(letterPrediction, letterTarget);

          if (!FLAGS_sclite.empty()) {
            std::string suffix = " (" + sampleId + ")\n";
            writeHyp(wordPredictionStr + suffix);
            writeRef(wordTargetStr + suffix);
          }

          if (FLAGS_show) {
            meters.wrdDst.reset();
            meters.tknDst.reset();
            meters.wrdDst.add(wordPrediction, wordTarget);
            meters.tknDst.add(letterPrediction, letterTarget);

            std::stringstream buffer;
            buffer << "|T|: " << wordTargetStr << std::endl;
            buffer << "|P|: " << wordPredictionStr << std::endl;
            if (FLAGS_showletters) {
              buffer << "|t|: " << join(" ", letterTarget) << std::endl;
              buffer << "|p|: " << join(" ", letterPrediction) << std::endl;
            }
            buffer << "[sample: " << sampleId
                   << ", WER: " << meters.wrdDst.errorRate()[0]
                   << "\%, TER: " << meters.tknDst.errorRate()[0]
                   << "\%, slice WER: " << meters.wrdDstSlice.errorRate()[0]
                   << "\%, slice TER: " << meters.tknDstSlice.errorRate()[0]
                   << "\%, decoded samples (thread " << tid
                   << "): " << sliceNumSamples[tid] + 1 << "]" << std::endl;

            std::cout << buffer.str();

            if (!FLAGS_sclite.empty()) {
              writeLog(buffer.str());
            }
          }

          // Update counters
          sliceNumWords[tid] += wordTarget.size();
          sliceNumTokens[tid] += letterTarget.size();
          sliceTime[tid] += meters.timer.value();
          sliceNumSamples[tid] += 1;
        }
        // Beam Dump
        else {
          meters.wrdDst.reset();
          meters.wrdDst.add(wordPrediction, wordTarget);
          auto wer = meters.wrdDst.errorRate()[0];

          if (FLAGS_sclite.empty()) {
            LOG(FATAL) << "FLAGS_sclite is empty, nowhere to dump the beam.";
          }

          auto score = results[i].score;
          auto amScore = results[i].amScore;
          auto lmScore = results[i].lmScore;
          auto outString = sampleId + " | " + std::to_string(score) + " | " +
              std::to_string(amScore) + " | " + std::to_string(lmScore) +
              " | " + std::to_string(wer) + " | " + wordPredictionStr + "\n";
          writeHyp(outString);
        }
      }
    }
    sliceWrdDst[tid] = meters.wrdDstSlice.value()[0];
    sliceTknDst[tid] = meters.tknDstSlice.value()[0];
  };

  /* ===================== Spawn threads ===================== */
  if (FLAGS_nthread_decoder_am_forward <= 0) {
    LOG(FATAL) << "FLAGS_nthread_decoder_am_forward ("
               << FLAGS_nthread_decoder_am_forward << ") need to be positive ";
  }
  if (FLAGS_nthread_decoder <= 0) {
    LOG(FATAL) << "FLAGS_nthread_decoder (" << FLAGS_nthread_decoder
               << ") need to be positive ";
  }

  auto startThreadsAndJoin = [&runAmForward, &runDecoder, &emissionQueue](
                                 int nAmThreads, int nDecoderThreads) {
    // TODO: possibly wrap the futures in try/catch for proper logging of all errors
    // https://github.com/facebookresearch/gtn/blob/master/gtn/parallel/parallel_map.h#L154

    // AM forwarding and decoding have to run sequentially to avoid GPU OOM
    // with two large neural nets in memory at once.
    if (FLAGS_lmtype == "convlm") {
      // 1. AM forwarding
      {
        std::vector<std::future<void>> futs(nAmThreads);
        fl::ThreadPool threadPool(nAmThreads);
        for (int i = 0; i < nAmThreads; i++) {
          futs[i] = threadPool.enqueue(runAmForward, i);
        }
        for (int i = 0; i < nAmThreads; i++) {
          futs[i].get();
        }
        emissionQueue.finishAdding();
      }
      // 2. Decoding
      {
        std::vector<std::future<void>> futs(nDecoderThreads);
        fl::ThreadPool threadPool(nDecoderThreads);
        for (int i = 0; i < nDecoderThreads; i++) {
          futs[i] = threadPool.enqueue(runDecoder, i);
        }
        for (int i = 0; i < nDecoderThreads; i++) {
          futs[i].get();
        }
      }
    }
    // Non-convLM decoding. AM forwarding and decoding can be run in parallel.
    else {
      std::vector<std::future<void>> futs(nAmThreads + nDecoderThreads);
      fl::ThreadPool threadPool(nAmThreads + nDecoderThreads);
      // AM forwarding threads
      for (int i = 0; i < nAmThreads; i++) {
        futs[i] = threadPool.enqueue(runAmForward, i);
      }
      // Decoding threads
      for (int i = 0; i < nDecoderThreads; i++) {
        futs[i + nAmThreads] = threadPool.enqueue(runDecoder, i);
      }

      for (int i = 0; i < nAmThreads; i++) {
        futs[i].get();
      }
      emissionQueue.finishAdding();
      for (int i = nAmThreads; i < nAmThreads + nDecoderThreads; i++) {
        futs[i].get();
      }
    }
  };
  auto timer = fl::TimeMeter();
  timer.resume();
  startThreadsAndJoin(FLAGS_nthread_decoder_am_forward, FLAGS_nthread_decoder);
  timer.stop();

  /* Compute statistics */
  int totalTokens = 0, totalWords = 0, totalSamples = 0;
  for (int i = 0; i < FLAGS_nthread_decoder; i++) {
    totalTokens += sliceNumTokens[i];
    totalWords += sliceNumWords[i];
    totalSamples += sliceNumSamples[i];
  }
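  // WER/TER below are computed as aggregated edit distance over aggregated
  // reference length, times 100; if errors were recorded against an empty
  // reference, the rate is reported as infinity.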
  double totalWer = 0, totalTkn = 0, totalTime = 0;
  for (int i = 0; i < FLAGS_nthread_decoder; i++) {
    totalWer += sliceWrdDst[i];
    totalTkn += sliceTknDst[i];
    totalTime += sliceTime[i];
  }
  if (totalWer > 0 && totalWords == 0) {
    totalWer = std::numeric_limits<double>::infinity();
  } else {
    totalWer = totalWords > 0 ? totalWer / totalWords * 100. : 0.0;
  }
  if (totalTkn > 0 && totalTokens == 0) {
    totalTkn = std::numeric_limits<double>::infinity();
  } else {
    totalTkn = totalTokens > 0 ? totalTkn / totalTokens * 100. : 0.0;
  }
  std::stringstream buffer;
  buffer << "------\n";
  buffer << "[Decode " << FLAGS_test << " (" << totalSamples << " samples) in "
         << timer.value() << "s (actual decoding time " << std::setprecision(3)
         << totalTime / totalSamples
         << "s/sample) -- WER: " << std::setprecision(6) << totalWer
         << "\%, TER: " << totalTkn << "\%]" << std::endl;
  LOG(INFO) << buffer.str();
  if (!FLAGS_sclite.empty()) {
    writeLog(buffer.str());
    hypStream.close();
    refStream.close();
    logStream.close();
  }
  return 0;
}