int main()

in recipes/local_prior_match/Train_lpm.cpp [30:456]


int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  google::InstallFailureSignalHandler();
  std::string exec(argv[0]);

  gflags::SetUsageMessage(
      "Usage: \n " + exec + " train [flags]\n or " + std::string() +
      " continue [directory] [flags]\n or " + std::string(argv[0]) +
      " fork [directory/model] [flags]");

  if (argc <= 1) {
    LOG(FATAL) << gflags::ProgramUsage();
  }

  /* ===================== Parse Options ===================== */
  auto config = setFlags(argc, argv);

  int runIdx = std::stoi(config[kRunIdx]);
  std::string reloadPath = config[kReloadPath];
  std::string propPath = config[kPropPath];
  int startEpoch = std::stoi(config[kStartEpoch]);
  int startIter = std::stoi(config[kStartIter]);
  std::string runPath = config[kRunPath];
  std::string runStatus = config[kRunStatus];

  /* ================ Set up distributed environment ================ */
  af::setSeed(FLAGS_seed);
  af::setFFTPlanCacheSize(FLAGS_fftcachesize);

  std::shared_ptr<fl::Reducer> reducer = nullptr;
  if (FLAGS_enable_distributed) {
    initDistributed(
        FLAGS_world_rank,
        FLAGS_world_size,
        FLAGS_max_devices_per_node,
        FLAGS_rndv_filepath);
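    // Average gradients across workers: the reducer scales contributions by
    // 1 / worldSize while coalescing them for all-reduce.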
    reducer = std::make_shared<fl::CoalescingReducer>(
        1.0 / fl::getWorldSize(), true, true);
  }

  int worldRank = fl::getWorldRank();
  int worldSize = fl::getWorldSize();
  bool isMaster = (worldRank == 0);

  LOG_MASTER(INFO) << "Gflags after parsing \n" << serializeGflags("; ");
  LOG_MASTER(INFO) << "Experiment path: " << runPath;
  LOG_MASTER(INFO) << "Experiment runidx: " << runIdx;

  /* ===================== Create Dictionary & Lexicon ===================== */
  auto dictPath = pathsConcat(FLAGS_tokensdir, FLAGS_tokens);
  Dictionary amDict(dictPath);
  // Setup-specific modifications
  if (FLAGS_eostoken) {
    amDict.addEntry(kEosToken);
  }
  int numClasses = amDict.indexSize();
  LOG_MASTER(INFO) << "Number of classes (network) = " << numClasses;

  DictionaryMap dicts;
  dicts.insert({kTargetIdx, amDict});

  // Note: fairseq vocab should start with:
  // <fairseq_style> - 0, <pad> - 1, kEosToken - 2, kUnkToken - 3
  Dictionary lmDict(FLAGS_lmdict);
  lmDict.setDefaultIndex(lmDict.getIndex(kUnkToken));

  auto lexicon = loadWords(FLAGS_lexicon, FLAGS_maxword);

  /* =========== Create Network & Optimizers / Reload Snapshot ============ */
  std::shared_ptr<fl::Module> network;
  std::shared_ptr<Seq2SeqCriterion> criterion;
  std::shared_ptr<fl::FirstOrderOptimizer> netoptim;

  if (runStatus == kTrainMode) {
    auto archfile = pathsConcat(FLAGS_archdir, FLAGS_arch);
    LOG_MASTER(INFO) << "Loading architecture file from " << archfile;
    auto numFeatures = getSpeechFeatureSize();

    network = createW2lSeqModule(archfile, numFeatures, numClasses);
    criterion = std::make_shared<Seq2SeqCriterion>(
        buildSeq2Seq(numClasses, amDict.getIndex(kEosToken)));
  } else {
    std::unordered_map<std::string, std::string> cfg; // unused
    std::shared_ptr<SequenceCriterion> base_criterion;
    W2lSerializer::load(reloadPath, cfg, network, base_criterion, netoptim);
    criterion = std::dynamic_pointer_cast<Seq2SeqCriterion>(base_criterion);
  }

  // create LM
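  // The pretrained LM is wrapped so that targets indexed with the AM
  // dictionary can be scored directly (dictIndexMap presumably remaps AM token
  // indices to LM indices); the LM stays in eval mode and is not optimized.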
  std::shared_ptr<fl::Module> lmNetwork;
  W2lSerializer::load(FLAGS_lm, lmNetwork);
  auto dictIndexMap = genTokenDictIndexMap(amDict, lmDict);
  auto lm = std::make_shared<LMWrapper>(
      lmNetwork, dictIndexMap, lmDict.getIndex(kEosToken));

  LOG_MASTER(INFO) << "[Network] " << network->prettyString();
  LOG_MASTER(INFO) << "[Network Params: " << numTotalParams(network) << "]";
  LOG_MASTER(INFO) << "[Criterion] " << criterion->prettyString();
  LOG_MASTER(INFO) << "[Criterion Params: " << numTotalParams(criterion) << "]";
  LOG_MASTER(INFO) << "[LM] " << lm->prettyString();
  LOG_MASTER(INFO) << "[LM Params: " << numTotalParams(lm) << "]";

  if (runStatus != kContinueMode) {
    netoptim = initOptimizer(
        {network, criterion},
        FLAGS_netoptim,
        FLAGS_lr,
        FLAGS_momentum,
        FLAGS_weightdecay);
  }
  LOG_MASTER(INFO) << "[Optimizer] " << netoptim->prettyString();

  /* =========== Load Proposal Network ============ */
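  // The proposal model is a previously trained network/criterion checkpoint;
  // it is used below to beam-search hypothesis transcripts for unpaired audio.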
  LOG(INFO) << "Load proposal model from " << propPath;
  std::unordered_map<std::string, std::string> propcfg; // unused
  std::shared_ptr<fl::Module> propnet;
  std::shared_ptr<SequenceCriterion> base_propcrit;
  std::shared_ptr<Seq2SeqCriterion> propcrit;

  W2lSerializer::load(propPath, propcfg, propnet, base_propcrit);
  propcrit = std::dynamic_pointer_cast<Seq2SeqCriterion>(base_propcrit);

  /* ===================== Create Dataset ===================== */
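  // Two training sources: transcribed data (FLAGS_train) and audio without
  // transcripts (FLAGS_trainaudio), each with its own batch size.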
  auto pairedDs = createDataset(
      FLAGS_train, dicts, lexicon, FLAGS_batchsize, worldRank, worldSize);
  auto unpairedAudioDs = createDataset(
      FLAGS_trainaudio,
      dicts,
      lexicon,
      FLAGS_unpairedBatchsize,
      worldRank,
      worldSize);

  if (FLAGS_noresample) {
    LOG_MASTER(INFO) << "Shuffling trainset";
    pairedDs->shuffle(FLAGS_seed);
    unpairedAudioDs->shuffle(FLAGS_seed);
  }

  auto trainEvalIds =
      getTrainEvalIds(pairedDs->size(), FLAGS_pcttraineval, FLAGS_seed);

  auto validSets = split(',', trim(FLAGS_valid));
  std::unordered_map<std::string, std::shared_ptr<W2lDataset>> validds;
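  // FLAGS_valid is a comma-separated list of validation sets, each given
  // either as a plain dataset path or as a "name:path" pair.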
  for (const auto& s : validSets) {
    auto ts = splitOnAnyOf(":", s);
    auto setKey = ts.size() == 1 ? s : ts[0];
    auto setValue = ts.size() == 1 ? s : ts[1];

    validds[setKey] = createDataset(
        setValue, dicts, lexicon, FLAGS_batchsize, worldRank, worldSize);
  }

  /* ===================== Training Dataset Scheduler ===================== */
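  // The scheduler interleaves the two datasets: each epoch draws
  // FLAGS_pairediter paired batches and FLAGS_audioiter unpaired-audio batches.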
  DataScheduler trainDscheduler(
      {pairedDs, unpairedAudioDs},
      {kParallelData, kUnpairedAudio},
      {FLAGS_pairediter, FLAGS_audioiter},
      startEpoch + 1);

  int64_t nItersPerEpoch = FLAGS_pairediter + FLAGS_audioiter;

  /* ===================== Meters ===================== */
  SSLTrainMeters meters;
  for (const auto& s : validds) {
    meters.valid[s.first] = SSLDatasetMeters();
  }

  /* ===================== Logging ===================== */
  bool logOnEpoch = FLAGS_reportiters == 0;
  LogHelper logHelper(runIdx, runPath, isMaster, logOnEpoch);
  logHelper.saveConfig(config);
  logHelper.writeHeader(meters);

  /* ===================== Hooks ===================== */
  if (reducer) {
    fl::distributeModuleGrads(network, reducer);
    fl::distributeModuleGrads(criterion, reducer);
  }

  fl::allReduceParameters(network);
  fl::allReduceParameters(criterion);

  /* ===================== Training starts ===================== */
  int64_t curEpoch = startEpoch;
  int64_t curIter = startIter;
  bool isPairedData;
  network->train();
  criterion->train();
  lm->eval();
  propnet->eval();
  propcrit->eval();

  logHelper.saveModel("prop.bin", propcfg, propnet, propcrit);
  runEval(propnet, propcrit, validds, meters, dicts[kTargetIdx]);
  syncMeter(meters);
  double properr = avgValidErr(meters);
  LOG_MASTER(INFO) << "Initial ProposalNetwork Err = " << properr;

  resetTrainMeters(meters);

  while (curEpoch < FLAGS_iter) {
    double lrScale = std::pow(FLAGS_gamma, curEpoch / FLAGS_stepsize);
    netoptim->setLr(lrScale * FLAGS_lr);

    ++curEpoch;
    af::sync();
    meters.timer[kSampleTimer].resume();
    meters.timer[kRuntime].resume();
    meters.timer[kTimer].resume();
    LOG_MASTER(INFO) << "Epoch " << curEpoch << " started!";
    LOG_MASTER(INFO) << "  Learning rate = " << netoptim->getLr();

    int scheduleIter = 0;
    while (scheduleIter < nItersPerEpoch) {
      auto sample = trainDscheduler.get();
      isPairedData = af::allTrue<bool>(sample[kDataTypeIdx] == kParallelData);
      ++curIter;
      ++scheduleIter;
      af::sync();
      int bs = isPairedData ? FLAGS_batchsize : FLAGS_unpairedBatchsize;

      meters.timer[kTimer].incUnit();
      meters.timer[kSampleTimer].stopAndIncUnit();
      meters.stats.add(sample[kInputIdx], sample[kTargetIdx]);
      if (af::anyTrue<bool>(af::isNaN(sample[kInputIdx])) ||
          af::anyTrue<bool>(af::isNaN(sample[kTargetIdx]))) {
        LOG(FATAL) << "Sample has NaN values";
      }

      // forward
      meters.timer[kFwdTimer].resume();
      auto output = network->forward({fl::input(sample[kInputIdx])}).front();
      af::sync();

      fl::Variable loss;
      fl::Variable lment;
      auto targets = fl::noGrad(sample[kTargetIdx]);
      auto tgtLen = getTargetLength(
          targets.array(), dicts[kTargetIdx].getIndex(kEosToken));
      if (isPairedData) {
        meters.timer[kCritFwdTimer].resume();
        loss = criterion->forward({output, targets}).front();

        if (af::anyTrue<bool>(af::isNaN(loss.array()))) {
          LOG(FATAL) << "ASR loss has NaN values";
        }
        meters.train.values[kASRLoss].add(loss.array());
        meters.timer[kCritFwdTimer].stopAndIncUnit();
      } else {
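        // Unpaired audio: beam-search hypothesis transcripts with the proposal
        // network, treat them as targets, and weight the seq2seq loss by the
        // renormalized LM score (FLAGS_lmweight), i.e. the local prior match
        // (LPM) loss.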
        fl::Variable lmLogprob;
        meters.timer[kBeamTimer].resume();
        std::vector<std::vector<int>> paths;
        std::vector<int> hypoNums;
        auto propoutput =
            propnet->forward({fl::input(sample[kInputIdx])}).front();
        std::tie(paths, hypoNums) = batchBeamSearch(
            propoutput, propcrit, dicts[kTargetIdx].getIndex(kEosToken));
        meters.timer[kBeamTimer].stopAndIncUnit();

        auto refLen = afToVector<int>(tgtLen);
        std::tie(paths, hypoNums) = filterBeamByLength(paths, hypoNums, refLen);
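        // remIdx keeps the batch entries that still have at least one
        // hypothesis after length filtering; remBs may be smaller than bs.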
        auto hypoNumsArr =
            af::array(af::dim4(hypoNums.size()), hypoNums.data());
        af::array remIdx = af::sort(af::where(hypoNumsArr));
        int remBs = remIdx.dims()[0];

        if (remBs == 0) {
          LOG(INFO) << "WARNING : using a made-up loss because of empty batch";
          tgtLen = af::constant(0, {1}, s32);
          // create a made-up loss with 0 value that is a function of
          // parameters to train, so the grad will be all 0.
          loss = criterion->forward({output, fl::noGrad(sample[kTargetIdx])})
                     .front();
          loss = 0.0 * loss;
        } else {
          targets = fl::noGrad(
              batchTarget(paths, dicts[kTargetIdx].getIndex(kEosToken)));
          tgtLen = getTargetLength(
              targets.array(), dicts[kTargetIdx].getIndex(kEosToken));

          meters.timer[kLMFwdTimer].resume();
          lmLogprob =
              fl::negate(lm->forward({targets, fl::noGrad(tgtLen)}).front());
          meters.timer[kLMFwdTimer].stopAndIncUnit();

          meters.timer[kBeamFwdTimer].resume();
          hypoNums = afToVector<int>(hypoNumsArr(remIdx));
          output =
              batchEncoderOutput(hypoNums, output(af::span, af::span, remIdx));
          loss = criterion->forward({output, targets}).front();

          auto lmRenormProb = adjustProb(lmLogprob, hypoNums, true, true);
          loss = FLAGS_lmweight * lmRenormProb * loss;
          meters.timer[kBeamFwdTimer].stopAndIncUnit();

          meters.values[kLen].add(tgtLen);
          meters.values[kNumHypos].add(static_cast<double>(paths.size()));

          lment = entropy(lmRenormProb) / static_cast<float>(hypoNums.size());
          meters.values[kLMEnt].add(lment.array());
          meters.values[kLMScore].add(lmLogprob.array());

          if (af::anyTrue<bool>(af::isNaN(loss.array()))) {
            LOG(FATAL) << "LPM loss has NaN values";
          }
          meters.values[kLPMLoss].add(loss.array());
        }
      }

      af::sync();
      meters.timer[kFwdTimer].stopAndIncUnit();
      meters.values[kFullLoss].add(loss.array());

      // compute training error rate from parallel data
      if (isPairedData) {
        auto globalBatchIdx = afToVector<int64_t>(sample[kGlobalBatchIdx]);
        if (trainEvalIds.find(globalBatchIdx[0]) != trainEvalIds.end()) {
          evalOutput(
              output.array(),
              sample[kTargetIdx],
              meters.train.edits,
              dicts[kTargetIdx],
              criterion);
        }
      }

      // backward
      meters.timer[kBwdTimer].resume();
      netoptim->zeroGrad();
      lm->zeroGrad();

      loss.backward();
      if (reducer) {
        reducer->finalize();
      }

      af::sync();
      meters.timer[kBwdTimer].stopAndIncUnit();
      meters.timer[kOptimTimer].resume();

      // Scale down gradients by the batch size. Note that the original batch
      // size bs is used instead of remBs, since different workers may have
      // different remBs; for simplicity we just use bs.
      for (const auto& p : network->params()) {
        if (!p.isGradAvailable()) {
          continue;
        }
        p.grad() = p.grad() / bs;
      }
      for (const auto& p : criterion->params()) {
        if (!p.isGradAvailable()) {
          continue;
        }
        p.grad() = p.grad() / bs;
      }
      if (FLAGS_maxgradnorm > 0) {
        auto params = network->params();
        auto critparams = criterion->params();
        params.insert(params.end(), critparams.begin(), critparams.end());
        fl::clipGradNorm(params, FLAGS_maxgradnorm);
      }

      netoptim->step();
      af::sync();
      meters.timer[kOptimTimer].stopAndIncUnit();
      meters.timer[kSampleTimer].resume();

      auto lengths = afToVector<int>(tgtLen);
      LOG(INFO) << "[ Epoch " << curEpoch << " ]"
                << " Iter=" << scheduleIter << " isPairedData=" << isPairedData
                << " AvgLoss=" << fl::mean(loss, {0}).scalar<float>()
                << " MinLen="
                << *std::min_element(lengths.begin(), lengths.end())
                << " MaxLen="
                << *std::max_element(lengths.begin(), lengths.end());

      // checkpoint evaluation
      if ((!logOnEpoch && curIter % FLAGS_reportiters == 0) ||
          (logOnEpoch && scheduleIter == nItersPerEpoch)) {
        stopTimeMeters(meters);
        runEval(network, criterion, validds, meters, dicts[kTargetIdx]);

        config[kEpoch] = std::to_string(curEpoch);
        config[kIteration] = std::to_string(curIter);
        std::unordered_map<std::string, double> logFields(
            {{"lr", netoptim->getLr()}});
        logHelper.logAndSaveModel(
            meters, config, network, criterion, netoptim, logFields);

        resetTrainMeters(meters);
        network->train();
        criterion->train();
        meters.timer[kSampleTimer].resume();
        meters.timer[kRuntime].resume();
        meters.timer[kTimer].resume();

        // maybe update proposal network
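        // kAlways replaces the proposal model at every checkpoint; kBetter
        // replaces it only when the new validation error improves on the
        // stored proposal error.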
        double newproperr = avgValidErr(meters);
        LOG_MASTER(INFO) << "ProposalNetwork:"
                         << " new=" << newproperr << " old=" << properr;
        if ((FLAGS_propupdate == kAlways) ||
            (FLAGS_propupdate == kBetter && properr > newproperr)) {
          LOG_MASTER(INFO) << "Update proposal model to the current model";
          logHelper.saveModel("prop.bin", config, network, criterion);
          properr = newproperr;

          std::string workerPropPath = logHelper.saveModel(
              format("prop_worker%03d.bin", worldRank),
              config,
              network,
              criterion,
              nullptr, // no optimizer for the proposal model
              true);
          W2lSerializer::load(workerPropPath, propcfg, propnet, base_propcrit);
          propcrit = std::dynamic_pointer_cast<Seq2SeqCriterion>(base_propcrit);
          propnet->eval();
          propcrit->eval();
        }
      }
    }
    af::sync();
  }

  LOG_MASTER(INFO) << "Finished training";
  return 0;
}