int main()

in recipes/joint_training_vox_populi/cpc/Test.cpp [43:443]


int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  google::InstallFailureSignalHandler();
  std::string exec(argv[0]);
  std::vector<std::string> argvs;
  for (int i = 0; i < argc; i++) {
    argvs.emplace_back(argv[i]);
  }
  gflags::SetUsageMessage("Usage: Please refer to https://git.io/JvJuR");
  if (argc <= 1) {
    LOG(FATAL) << gflags::ProgramUsage();
  }

  /* ===================== Parse Options ===================== */
  LOG(INFO) << "Parsing command line flags";
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  auto flagsfile = FLAGS_flagsfile;
  if (!flagsfile.empty()) {
    LOG(INFO) << "Reading flags from file " << flagsfile;
    gflags::ReadFromFlagsFile(flagsfile, argv[0], true);
  }

  if (!FLAGS_fl_log_level.empty()) {
    fl::Logging::setMaxLoggingLevel(fl::logLevelValue(FLAGS_fl_log_level));
  }
  fl::VerboseLogging::setMaxLoggingLevel(FLAGS_fl_vlog_level);

  /* ===================== Create Network ===================== */
  std::shared_ptr<fl::Sequential> network;
  std::shared_ptr<SequenceCriterion> _criterion;
  std::shared_ptr<SequenceCriterion> criterion;
  std::unordered_map<std::string, std::string> cfg;
  std::string version;
  LOG(INFO) << "[Network] Reading acoustic model from " << FLAGS_am;
  af::setDevice(0);
  Serializer::load(FLAGS_am, version, cfg, network, _criterion, criterion);
  if (version != FL_APP_ASR_VERSION) {
    LOG(WARNING) << "[Network] Model version " << version
                 << " and code version " << FL_APP_ASR_VERSION;
  }
  network->eval();
  criterion->eval();

  LOG(INFO) << "[Network] " << network->prettyString();
  LOG(INFO) << "[Criterion] " << criterion->prettyString();
  LOG(INFO) << "[Network] Number of params: " << numTotalParams(network);

  auto flags = cfg.find(kGflags);
  if (flags == cfg.end()) {
    LOG(FATAL) << "[Network] Invalid config loaded from " << FLAGS_am;
  }
  LOG(INFO) << "[Network] Updating flags from config file: " << FLAGS_am;
  gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);

  // override with user-specified flags
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  if (!flagsfile.empty()) {
    gflags::ReadFromFlagsFile(flagsfile, argv[0], true);
  }

  // Only Copy any values from deprecated flags to new flags when deprecated
  // flags are present and corresponding new flags aren't
  handleDeprecatedFlags();

  LOG(INFO) << "Gflags after parsing \n" << serializeGflags("; ");

  /* ===================== Create Dictionary ===================== */
  auto dictPath = FLAGS_tokens;
  if (dictPath.empty() || !fl::lib::fileExists(dictPath)) {
    throw std::runtime_error("Invalid dictionary filepath specified.");
  }
  fl::lib::text::Dictionary tokenDict(dictPath);
  // Setup-specific modifications
  for (int64_t r = 1; r <= FLAGS_replabel; ++r) {
    tokenDict.addEntry("<" + std::to_string(r) + ">");
  }
  // ctc expects the blank label last
  if (FLAGS_criterion2 == kCtcCriterion) {
    tokenDict.addEntry(kBlankToken);
  }
  if (FLAGS_eostoken) {
    tokenDict.addEntry(fl::app::asr::kEosToken);
  }

  int numClasses = tokenDict.indexSize();
  LOG(INFO) << "Number of classes (network): " << numClasses;

  fl::lib::text::Dictionary wordDict;
  fl::lib::text::LexiconMap lexicon;
  if (!FLAGS_lexicon.empty()) {
    lexicon = fl::lib::text::loadWords(FLAGS_lexicon, FLAGS_maxword);
    wordDict = fl::lib::text::createWordDict(lexicon);
    LOG(INFO) << "Number of words: " << wordDict.indexSize();
  }

  fl::lib::text::DictionaryMap dicts = {
      {kTargetIdx, tokenDict}, {kWordIdx, wordDict}};

  /* ===================== Create Dataset ===================== */
  fl::lib::audio::FeatureParams featParams(
      FLAGS_samplerate,
      FLAGS_framesizems,
      FLAGS_framestridems,
      FLAGS_filterbanks,
      FLAGS_lowfreqfilterbank,
      FLAGS_highfreqfilterbank,
      FLAGS_mfcccoeffs,
      kLifterParam /* lifterparam */,
      FLAGS_devwin /* delta window */,
      FLAGS_devwin /* delta-delta window */);
  featParams.useEnergy = false;
  featParams.usePower = false;
  featParams.zeroMeanFrame = false;
  FeatureType featType = FeatureType::NONE;
  if (FLAGS_pow) {
    featType = FeatureType::POW_SPECTRUM;
  } else if (FLAGS_mfsc) {
    featType = FeatureType::MFSC;
  } else if (FLAGS_mfcc) {
    featType = FeatureType::MFCC;
  }
  TargetGenerationConfig targetGenConfig(
      FLAGS_wordseparator,
      FLAGS_sampletarget,
      FLAGS_criterion2,
      FLAGS_surround,
      FLAGS_eostoken,
      FLAGS_replabel,
      true /* skip unk */,
      FLAGS_usewordpiece /* fallback2LetterWordSepLeft */,
      !FLAGS_usewordpiece /* fallback2LetterWordSepLeft */);

  auto inputTransform = inputFeatures(
      featParams,
      featType,
      {FLAGS_localnrmlleftctx, FLAGS_localnrmlrightctx},
      /*sfxConf=*/{});
  auto targetTransform = targetFeatures(tokenDict, lexicon, targetGenConfig);
  auto wordTransform = wordFeatures(wordDict);
  int targetpadVal = FLAGS_eostoken
      ? tokenDict.getIndex(fl::app::asr::kEosToken)
      : kTargetPadValue;
  int wordpadVal = wordDict.getIndex(fl::lib::text::kUnkToken);

  std::vector<std::string> testSplits = fl::lib::split(",", FLAGS_test, true);
  auto ds = createDataset(
      testSplits,
      FLAGS_datadir,
      1 /* batchsize */,
      inputTransform,
      targetTransform,
      wordTransform,
      std::make_tuple(0, targetpadVal, wordpadVal),
      0 /* worldrank */,
      1 /* worldsize */);

  int nSamples = ds->size();
  if (FLAGS_maxload > 0) {
    nSamples = std::min(nSamples, FLAGS_maxload);
  }
  LOG(INFO) << "[Dataset] Dataset loaded, with " << nSamples << " samples.";

  /* ===================== Test ===================== */
  std::vector<double> sliceWrdDst(FLAGS_nthread_decoder_am_forward);
  std::vector<double> sliceTknDst(FLAGS_nthread_decoder_am_forward);
  std::vector<int> sliceNumWords(FLAGS_nthread_decoder_am_forward, 0);
  std::vector<int> sliceNumTokens(FLAGS_nthread_decoder_am_forward, 0);
  std::vector<int> sliceNumSamples(FLAGS_nthread_decoder_am_forward, 0);
  std::vector<double> sliceTime(FLAGS_nthread_decoder_am_forward, 0);

  auto cleanTestPath = cleanFilepath(FLAGS_test);
  std::string emissionDir;
  if (!FLAGS_emission_dir.empty()) {
    emissionDir = pathsConcat(FLAGS_emission_dir, cleanTestPath);
    fl::lib::dirCreate(emissionDir);
  }

  // Prepare sclite log writer
  std::ofstream hypStream, refStream;
  if (!FLAGS_sclite.empty()) {
    auto hypPath = pathsConcat(FLAGS_sclite, cleanTestPath + ".hyp");
    auto refPath = pathsConcat(FLAGS_sclite, cleanTestPath + ".viterbi.ref");
    hypStream.open(hypPath);
    refStream.open(refPath);
    if (!hypStream.is_open() || !hypStream.good()) {
      LOG(FATAL) << "Error opening hypothesis file: " << hypPath;
    }
    if (!refStream.is_open() || !refStream.good()) {
      LOG(FATAL) << "Error opening reference file: " << refPath;
    }
  }

  std::mutex hypMutex, refMutex;
  auto writeHyp = [&hypMutex, &hypStream](const std::string& hypStr) {
    std::lock_guard<std::mutex> lock(hypMutex);
    hypStream << hypStr;
  };
  auto writeRef = [&refMutex, &refStream](const std::string& refStr) {
    std::lock_guard<std::mutex> lock(refMutex);
    refStream << refStr;
  };

  // Run test
  auto run = [&network,
              &criterion,
              &nSamples,
              &ds,
              &tokenDict,
              &wordDict,
              &writeHyp,
              &writeRef,
              &emissionDir,
              &sliceWrdDst,
              &sliceTknDst,
              &sliceNumWords,
              &sliceNumTokens,
              &sliceNumSamples,
              &sliceTime](int tid) {
    // Initialize AM
    af::setDevice(tid);
    std::shared_ptr<fl::Sequential> localNetwork = network;
    std::shared_ptr<SequenceCriterion> _localCriterion;
    std::shared_ptr<SequenceCriterion> localCriterion = criterion;
    if (tid != 0) {
      std::unordered_map<std::string, std::string> dummyCfg;
      std::string dummyVersion;
      Serializer::load(
          FLAGS_am,
          dummyVersion,
          dummyCfg,
          localNetwork,
          _localCriterion,
          localCriterion);
      localNetwork->eval();
      localCriterion->eval();
    }

    std::vector<int64_t> selectedIds;
    for (int64_t i = tid; i < nSamples; i += FLAGS_nthread_decoder_am_forward) {
      selectedIds.emplace_back(i);
    }
    std::shared_ptr<fl::Dataset> localDs =
        std::make_shared<fl::ResampleDataset>(ds, selectedIds);
    localDs = std::make_shared<fl::PrefetchDataset>(
        localDs, FLAGS_nthread, FLAGS_nthread);

    TestMeters meters;
    meters.timer.resume();
    int cnt = 0;
    for (auto& sample : *localDs) {
      int idx = 0;
      auto enc_out = localNetwork->module(idx++)
                         ->forward({fl::input(sample[kInputIdx])})
                         .front();
      enc_out = localNetwork->module(idx++)->forward({enc_out}).front();
      enc_out = localNetwork->module(idx++)->forward({enc_out}).front();
      enc_out = w2l::forwardSequentialModuleWithPadMaskForCPC(
          enc_out, localNetwork->module(idx++), sample[kDurationIdx]);
      auto rawEmission = localNetwork->module(idx)->forward({enc_out}).front();
      auto emission = afToVector<float>(rawEmission);
      auto tokenTarget = afToVector<int>(sample[kTargetIdx]);
      auto wordTarget = afToVector<int>(sample[kWordIdx]);
      auto sampleId = readSampleIds(sample[kSampleIdx]).front();

      auto letterTarget = tknTarget2Ltr(
          tokenTarget,
          tokenDict,
          FLAGS_criterion2,
          FLAGS_surround,
          FLAGS_eostoken,
          FLAGS_replabel,
          FLAGS_usewordpiece,
          FLAGS_wordseparator);
      std::vector<std::string> wordTargetStr;
      if (FLAGS_uselexicon) {
        wordTargetStr = wrdIdx2Wrd(wordTarget, wordDict);
      } else {
        wordTargetStr = tkn2Wrd(letterTarget, FLAGS_wordseparator);
      }

      // Tokens
      auto tokenPrediction =
          afToVector<int>(localCriterion->viterbiPath(rawEmission.array()));
      auto letterPrediction = tknPrediction2Ltr(
          tokenPrediction,
          tokenDict,
          FLAGS_criterion2,
          FLAGS_surround,
          FLAGS_eostoken,
          FLAGS_replabel,
          FLAGS_usewordpiece,
          FLAGS_wordseparator);

      meters.tknDstSlice.add(letterPrediction, letterTarget);

      // Words
      std::vector<std::string> wrdPredictionStr =
          tkn2Wrd(letterPrediction, FLAGS_wordseparator);
      meters.wrdDstSlice.add(wrdPredictionStr, wordTargetStr);

      if (!FLAGS_sclite.empty()) {
        writeRef(join(" ", wordTargetStr) + " (" + sampleId + ")\n");
        writeHyp(join(" ", wrdPredictionStr) + " (" + sampleId + ")\n");
      }

      if (FLAGS_show) {
        meters.tknDst.reset();
        meters.wrdDst.reset();
        meters.tknDst.add(letterPrediction, letterTarget);
        meters.wrdDst.add(wrdPredictionStr, wordTargetStr);

        std::cout << "|T|: " << join(" ", letterTarget) << std::endl;
        std::cout << "|P|: " << join(" ", letterPrediction) << std::endl;
        std::cout << "[sample: " << sampleId
                  << ", WER: " << meters.wrdDst.errorRate()[0]
                  << "\%, TER: " << meters.tknDst.errorRate()[0]
                  << "\%, total WER: " << meters.wrdDstSlice.errorRate()[0]
                  << "\%, total TER: " << meters.tknDstSlice.errorRate()[0]
                  << "\%, progress (thread " << tid << "): "
                  << static_cast<float>(++cnt) / selectedIds.size() * 100
                  << "\%]" << std::endl;
      }

      /* Save emission and targets */
      int nTokens = rawEmission.dims(0);
      int nFrames = rawEmission.dims(1);
      EmissionUnit emissionUnit(emission, sampleId, nFrames, nTokens);

      // Update counters
      sliceNumWords[tid] += wordTarget.size();
      sliceNumTokens[tid] += letterTarget.size();
      sliceNumSamples[tid]++;

      if (!emissionDir.empty()) {
        std::string savePath = pathsConcat(emissionDir, sampleId + ".bin");
        Serializer::save(savePath, FL_APP_ASR_VERSION, emissionUnit);
      }
    }

    meters.timer.stop();

    sliceWrdDst[tid] = meters.wrdDstSlice.value()[0];
    sliceTknDst[tid] = meters.tknDstSlice.value()[0];
    sliceTime[tid] = meters.timer.value();
  };

  /* Spread threades */
  // TODO possibly try catch for futures to proper logging of all errors
  // https://github.com/facebookresearch/gtn/blob/master/gtn/parallel/parallel_map.h#L154
  auto startThreadsAndJoin = [&run](int nThreads) {
    if (nThreads == 1) {
      run(0);
    } else if (nThreads > 1) {
      std::vector<std::future<void>> futs(nThreads);
      fl::ThreadPool threadPool(nThreads);
      for (int i = 0; i < nThreads; i++) {
        futs[i] = threadPool.enqueue(run, i);
      }
      for (int i = 0; i < nThreads; i++) {
        futs[i].get();
      }
    } else {
      LOG(FATAL) << "Invalid negative FLAGS_nthread_decoder_am_forward";
    }
  };
  auto timer = fl::TimeMeter();
  timer.resume();
  startThreadsAndJoin(FLAGS_nthread_decoder_am_forward);
  timer.stop();

  int totalTokens = 0, totalWords = 0, totalSamples = 0;
  for (int i = 0; i < FLAGS_nthread_decoder_am_forward; i++) {
    totalTokens += sliceNumTokens[i];
    totalWords += sliceNumWords[i];
    totalSamples += sliceNumSamples[i];
  }
  double totalWer = 0, totalTer = 0, totalTime = 0;
  for (int i = 0; i < FLAGS_nthread_decoder_am_forward; i++) {
    totalWer += sliceWrdDst[i];
    totalTer += sliceTknDst[i];
    totalTime += sliceTime[i];
  }
  if (totalWer > 0 && totalWords == 0) {
    totalWer = std::numeric_limits<double>::infinity();
  } else {
    totalWer = totalWords > 0 ? totalWer / totalWords * 100. : 0.0;
  }
  if (totalTer > 0 && totalTokens == 0) {
    totalTer = std::numeric_limits<double>::infinity();
  } else {
    totalTer = totalTokens > 0 ? totalTer / totalTokens * 100. : 0.0;
  }
  LOG(INFO) << "------";
  LOG(INFO) << "[Test " << FLAGS_test << " (" << totalSamples << " samples) in "
            << timer.value() << "s (actual decoding time "
            << std::setprecision(3) << totalTime / totalSamples
            << "s/sample) -- WER: " << std::setprecision(6) << totalWer
            << "\%, TER: " << totalTer << "\%]";

  return 0;
}