int main()

in recipes/joint_training_vox_populi/cpc/Train.cpp [222:1491]


int main(int argc, char** argv) {
  fl::init();
  std::string exec(argv[0]);
  std::vector<std::string> argvs;
  for (int i = 0; i < argc; i++) {
    argvs.emplace_back(argv[i]);
  }
  google::InitGoogleLogging(argv[0]);
  google::InstallFailureSignalHandler();
  gflags::SetUsageMessage(
      "Usage: \n " + exec + " train [flags]\n or " + exec +
      " continue [directory] [flags]\n or " + exec +
      " fork [directory/model] [flags]");

  /* ===================== Parse Options ===================== */
  int runIdx = 1; // current #runs in this path
  std::string runPath; // current experiment path
  std::string reloadPath; // path to model to reload
  if (argc <= 1) {
    FL_LOG(fl::FATAL) << gflags::ProgramUsage();
  }
  std::string runStatus = argv[1];
  int64_t startEpoch = 0;
  int64_t startBatch = 0;
  int64_t supStartBatch = 0;
  int64_t unsupStartBatch = 0;
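  // The first positional argument selects the run mode:
  //   train    - start a new experiment in FLAGS_rundir
  //   continue - resume the newest model_last.bin in an existing run
  //              directory, restoring flags, epoch and update counters
  //   fork     - initialize network/criterion from an existing model file
  //              but start a fresh run (optimizer state is re-created below)
  // Illustrative invocations (binary name and paths are hypothetical):
  //   ./Train train --flagsfile=cpc.cfg
  //   ./Train continue /checkpoint/cpc_run --lr=0.0001
  //   ./Train fork /checkpoint/cpc_run/model_last.bin --train2=labeled.lst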
  if (runStatus == kTrainMode) {
    FL_LOG(fl::INFO) << "Parsing command line flags";
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    if (!FLAGS_flagsfile.empty()) {
      FL_LOG(fl::INFO) << "Reading flags from file " << FLAGS_flagsfile;
      gflags::ReadFromFlagsFile(FLAGS_flagsfile, argv[0], true);
    }
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    runPath = FLAGS_rundir;
  } else if (runStatus == kContinueMode) {
    runPath = argv[2];
    while (fileExists(getRunFile("model_last.bin", runIdx, runPath))) {
      ++runIdx;
    }
    reloadPath = getRunFile("model_last.bin", runIdx - 1, runPath);
    FL_LOG(fl::INFO) << "reload path is " << reloadPath;
    std::unordered_map<std::string, std::string> cfg;
    std::string version;
    Serializer::load(reloadPath, version, cfg);
    auto flags = cfg.find(kGflags);
    if (flags == cfg.end()) {
      FL_LOG(fl::FATAL) << "Invalid config loaded from " << reloadPath;
    }
    FL_LOG(fl::INFO) << "Reading flags from config file " << reloadPath;
    gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);
    if (argc > 3) {
      FL_LOG(fl::INFO) << "Parsing command line flags";
      FL_LOG(fl::INFO)
          << "Overriding flags should be mutable when using `continue`";
      gflags::ParseCommandLineFlags(&argc, &argv, false);
    }
    if (!FLAGS_flagsfile.empty()) {
      FL_LOG(fl::INFO) << "Reading flags from file " << FLAGS_flagsfile;
      gflags::ReadFromFlagsFile(FLAGS_flagsfile, argv[0], true);
    }
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    auto epoch = cfg.find(kEpoch);
    if (epoch == cfg.end()) {
      FL_LOG(fl::WARNING)
          << "Did not find epoch to start from, starting from 0.";
    } else {
      startEpoch = std::stoi(epoch->second);
    }
    auto nbupdates = cfg.find(kUpdates);
    if (nbupdates == cfg.end()) {
      FL_LOG(fl::WARNING)
          << "Did not find #updates to start from, starting from 0.";
    } else {
      startBatch = std::stoi(nbupdates->second);
    }
  } else if (runStatus == kForkMode) {
    reloadPath = argv[2];
    std::unordered_map<std::string, std::string> cfg;
    std::string version;
    Serializer::load(reloadPath, version, cfg);
    auto flags = cfg.find(kGflags);
    if (flags == cfg.end()) {
      FL_LOG(fl::FATAL) << "Invalid config loaded from " << reloadPath;
    }

    FL_LOG(fl::INFO) << "Reading flags from config file " << reloadPath;
    gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);

    if (argc > 3) {
      FL_LOG(fl::INFO) << "Parsing command line flags";
      FL_LOG(fl::INFO)
          << "Overriding flags should be mutable when using `fork`";
      gflags::ParseCommandLineFlags(&argc, &argv, false);
    }

    if (!FLAGS_flagsfile.empty()) {
      FL_LOG(fl::INFO) << "Reading flags from file" << FLAGS_flagsfile;
      gflags::ReadFromFlagsFile(FLAGS_flagsfile, argv[0], true);
    }
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    runPath = FLAGS_rundir;
  } else {
    FL_LOG(fl::FATAL) << gflags::ProgramUsage();
  }
  // Only new flags are re-serialized. Copy any values from deprecated flags to
  // new flags when deprecated flags are present and corresponding new flags
  // aren't
  handleDeprecatedFlags();
  if (!FLAGS_fl_log_level.empty()) {
    fl::Logging::setMaxLoggingLevel(fl::logLevelValue(FLAGS_fl_log_level));
  }

  fl::VerboseLogging::setMaxLoggingLevel(FLAGS_fl_vlog_level);

  af::setSeed(FLAGS_seed);
  af::setFFTPlanCacheSize(FLAGS_fftcachesize);
  fl::DynamicBenchmark::setBenchmarkMode(FLAGS_fl_benchmark_mode);

  std::shared_ptr<fl::Reducer> reducer = nullptr;
  if (FLAGS_enable_distributed) {
    initDistributed(
        FLAGS_world_rank,
        FLAGS_world_size,
        FLAGS_max_devices_per_node,
        FLAGS_rndv_filepath);
    reducer = std::make_shared<fl::CoalescingReducer>(1.0, true, true);
  }

  int worldRank = fl::getWorldRank();
  int worldSize = fl::getWorldSize();
  bool isMaster = (worldRank == 0);

  FL_LOG_MASTER(INFO) << "Gflags after parsing \n"
                      << fl::pkg::speech::serializeGflags("; ");
  FL_LOG_MASTER(INFO) << "Experiment path: " << runPath;
  FL_LOG_MASTER(INFO) << "Experiment runidx: " << runIdx;

  auto flOptimLevel = FLAGS_fl_optim_mode.empty()
      ? fl::OptimLevel::DEFAULT
      : fl::OptimMode::toOptimLevel(FLAGS_fl_optim_mode);
  fl::OptimMode::get().setOptimLevel(flOptimLevel);
  if (FLAGS_fl_amp_use_mixed_precision) {
    // Only set the optim mode to O1 if it was left empty
    FL_LOG(fl::INFO)
        << "Mixed precision training enabled. Will perform loss scaling.";
    if (FLAGS_fl_optim_mode.empty()) {
      FL_LOG(fl::INFO) << "Mixed precision training enabled with no "
                          "optim mode specified - setting optim mode to O1.";
      fl::OptimMode::get().setOptimLevel(fl::OptimLevel::O1);
    }
  }

  std::unordered_map<std::string, std::string> config = {
      {kProgramName, exec},
      {kCommandLine, join(" ", argvs)},
      {kGflags, fl::pkg::speech::serializeGflags()},
      // extra goodies
      {kUserName, getEnvVar("USER")},
      {kHostName, getEnvVar("HOSTNAME")},
      {kTimestamp, getCurrentDate() + ", " + getCurrentTime()},
      {kRunIdx, std::to_string(runIdx)},
      {kRunPath, runPath}};

  auto validSets = split(',', trim(FLAGS_valid));
  std::vector<std::pair<std::string, std::string>> validTagSets;
  for (const auto& s : validSets) {
    // assume the format is tag:filepath
    auto ts = splitOnAnyOf(":", s);
    if (ts.size() == 1) {
      validTagSets.emplace_back(std::make_pair(s, s));
    } else {
      validTagSets.emplace_back(std::make_pair(ts[0], ts[1]));
    }
  }
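  // FLAGS_valid is a comma-separated list of validation sets; each entry is
  // either "tag:listfile" or just "listfile" (in which case the path doubles
  // as the tag). A hypothetical example:
  //   --valid dev-clean:dev-clean.lst,dev-other:dev-other.lst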

  /* ===================== Create Dictionary & Lexicon ===================== */
  auto dictPath = FLAGS_tokens;
  if (dictPath.empty() || !fileExists(dictPath)) {
    throw std::runtime_error("Invalid dictionary filepath specified.");
  }
  Dictionary tokenDict(dictPath);
  // Setup-specific modifications
  for (int64_t r = 1; r <= FLAGS_replabel; ++r) {
    tokenDict.addEntry("<" + std::to_string(r) + ">");
  }
  // ctc expects the blank label last
  if (FLAGS_criterion2 == kCtcCriterion) {
    tokenDict.addEntry(kBlankToken);
  }
  bool isSeq2seqCrit = FLAGS_criterion == kSeq2SeqTransformerCriterion ||
      FLAGS_criterion == kSeq2SeqRNNCriterion;
  if (isSeq2seqCrit) {
    tokenDict.addEntry(fl::pkg::speech::kEosToken);
    tokenDict.addEntry(fl::lib::text::kPadToken);
  }
  if (FLAGS_codedim == 0 || FLAGS_contextdim == 0) {
    throw std::runtime_error("Please specify encoder and context dims");
  }

  int numClasses = tokenDict.indexSize();
  FL_LOG(fl::INFO) << "Number of classes (network): " << numClasses;
  int numQuant = FLAGS_npieces * FLAGS_nunits;
  FL_LOG(fl::INFO) << "Number of quantized tokens (network): " << numQuant;

  Dictionary wordDict;
  LexiconMap lexicon;
  if (!FLAGS_lexicon.empty()) {
    lexicon = loadWords(FLAGS_lexicon, FLAGS_maxword);
    wordDict = createWordDict(lexicon);
    FL_LOG(fl::INFO) << "Number of words: " << wordDict.indexSize();
  }

  DictionaryMap dicts = {{kTargetIdx, tokenDict}, {kWordIdx, wordDict}};

  /* =========== Create Network & Optimizers / Reload Snapshot ============ */
  std::shared_ptr<fl::Sequential> network;
  std::shared_ptr<fl::Sequential> _network;
  std::shared_ptr<fl::Sequential> _feat_network;
  std::shared_ptr<fl::Linear> mtl_classifier;
  // unsupervised criterion
  std::shared_ptr<SequenceCriterion> criterion;
  // supervised criterion (all variables ending with 2 are supervised)
  std::shared_ptr<SequenceCriterion> criterion2;
  std::shared_ptr<fl::FirstOrderOptimizer> netoptim;
  std::shared_ptr<fl::FirstOrderOptimizer> netoptim2;
  std::shared_ptr<fl::FirstOrderOptimizer> critoptim;
  std::shared_ptr<fl::FirstOrderOptimizer> critoptim2;
  std::shared_ptr<fl::FirstOrderOptimizer> mtloptim;
  std::unordered_map<std::string, std::string> cfg;
  std::unordered_map<std::string, std::string> _cfg;
  std::map<std::string, unsigned int> mtl_mapping;

  FL_LOG(fl::INFO) << "SAUG";
  auto saug = std::make_shared<w2l::CPCSpecAugment>(
      FLAGS_contextdim, // default 80
      64,
      6,
      FLAGS_masklength,
      FLAGS_saug_maskprob * FLAGS_masklength,
      1);

  FL_LOG(fl::INFO) << "SAUG Done";

  auto scalemode = getCriterionScaleMode(FLAGS_onorm, FLAGS_sqnorm);

  FeatureParams featParams(
      FLAGS_samplerate,
      FLAGS_framesizems,
      FLAGS_framestridems,
      FLAGS_filterbanks,
      FLAGS_lowfreqfilterbank,
      FLAGS_highfreqfilterbank,
      FLAGS_mfcccoeffs,
      kLifterParam /* lifterparam */,
      FLAGS_devwin /* delta window */,
      FLAGS_devwin /* delta-delta window */);
  featParams.useEnergy = false;
  featParams.usePower = false;
  featParams.zeroMeanFrame = false;
  auto featureRes =
      getFeatureType(FLAGS_features_type, FLAGS_channels, featParams);
  int numFeatures = featureRes.first;
  FeatureType featType = featureRes.second;

  if (runStatus == kTrainMode) {
    // order of arch (network) files: encoder, context, predict
    std::vector<std::string> archfiles = split(',', trim(FLAGS_arch));
    network = std::make_shared<fl::Sequential>();

    FL_LOG(fl::INFO) << "Building the network";

    if (FLAGS_pretrainmodel.length() > 0) {
      FL_LOG(fl::INFO) << "Pretrain";
      std::string version;
      network = std::make_shared<fl::Sequential>();
      Serializer::load(FLAGS_pretrainmodel, version, _cfg, _network, criterion);
      FL_LOG(fl::INFO) << "Loaded";
      PartialLoading(-1, _network, network);
      FL_LOG(fl::INFO) << "[Criterion] " << criterion->prettyString();
    } else {
      FL_LOG(fl::INFO) << "Loading architecture file from " << archfiles[0];
      network->add(w2l::cpc::buildSequentialModule(
          archfiles[0], numFeatures, FLAGS_codedim));
      // two extra layers between the encoder and the context network so that
      // intermediate activations can be operated on directly
      network->add(std::make_shared<fl::LayerNorm>(std::vector<int>{0, 3}));
      network->add(
          std::make_shared<fl::Linear>(FLAGS_codedim, FLAGS_contextdim));
      FL_LOG(fl::INFO) << "Loading architecture file from " << archfiles[1];
      network->add(w2l::cpc::buildSequentialModule(
          archfiles[1], FLAGS_contextdim, FLAGS_contextdim));
    }
    FL_LOG(fl::INFO) << "Loading architecture file from " << archfiles[2];
    network->add(w2l::cpc::buildSequentialModule(
        archfiles[2], FLAGS_contextdim, numClasses));

    if (FLAGS_criterion2 == kCtcCriterion) {
      criterion2 = std::make_shared<CTCLoss>(scalemode);
    } else {
      FL_LOG(fl::FATAL) << "unimplemented criterion";
    }
    if ((FLAGS_pretrainmodel.length() == 0) &&
        (FLAGS_criterion == kCPCCriterion)) {
      criterion = std::make_shared<CPCCriterion>(
          FLAGS_codedim,
          FLAGS_contextdim,
          FLAGS_mutualdim,
          FLAGS_noffset,
          FLAGS_nunits,
          FLAGS_npieces,
          FLAGS_nnegativesamples,
          FLAGS_nbuffersamples,
          FLAGS_temperature);
      FL_LOG(fl::INFO) << "CPC criterion loaded";
    }
  } else if (runStatus == kForkMode) {
    FL_LOG(fl::INFO) << "Fork mode";
    std::unordered_map<std::string, std::string> cfg; // unused
    std::string version;
    Serializer::load(reloadPath, version, cfg, network, criterion);
  } else { // kContinueMode
    std::unordered_map<std::string, std::string> cfg; // unused
    std::string version;
    Serializer::load(
        reloadPath,
        version,
        cfg,
        network,
        criterion,
        criterion2,
        netoptim,
        netoptim2,
        critoptim,
        critoptim2);
  }
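  // In train mode the Sequential ends up with five top-level modules, and the
  // training/eval code below addresses them by index:
  //   module(0) - encoder (archfiles[0], numFeatures -> FLAGS_codedim)
  //   module(1) - fl::LayerNorm over the encoder output
  //   module(2) - fl::Linear projection, FLAGS_codedim -> FLAGS_contextdim
  //   module(3) - context network (archfiles[1])
  //   module(4) - predictor head (archfiles[2], FLAGS_contextdim -> numClasses)
  // Models reloaded via `continue`, `fork`, or --pretrainmodel are assumed to
  // follow the same layout.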
  FL_LOG(fl::INFO) << "[Network] " << network->prettyString();
  FL_LOG(fl::INFO) << "[Network Params: " << numTotalParams(network) << "]";
  FL_LOG(fl::INFO) << "[Criterion] " << criterion->prettyString();
  FL_LOG(fl::INFO) << "[Criterion2] " << criterion2->prettyString();

  if (runStatus == kTrainMode || runStatus == kForkMode) {
    netoptim = initOptimizer(
        {network}, FLAGS_netoptim, FLAGS_lr, FLAGS_momentum, FLAGS_weightdecay);
    netoptim2 = initOptimizer(
        {network},
        FLAGS_netoptim2,
        FLAGS_lr2,
        FLAGS_momentum,
        FLAGS_weightdecay);
    critoptim =
        initOptimizer({criterion}, FLAGS_critoptim, FLAGS_lrcrit, 0.0, 0.0);
    critoptim2 =
        initOptimizer({criterion2}, FLAGS_critoptim2, FLAGS_lrcrit2, 0.0, 0.0);
  }
  FL_LOG(fl::INFO) << "[Network Optimizer] " << netoptim->prettyString();
  FL_LOG(fl::INFO) << "[Network2 Optimizer] " << netoptim2->prettyString();
  FL_LOG(fl::INFO) << "[Criterion Optimizer] " << critoptim->prettyString();
  FL_LOG(fl::INFO) << "[Criterion2 Optimizer] " << critoptim2->prettyString();

  TrainMeters meters;
  TrainMeters meters2;
  for (const auto& s : validTagSets) {
    meters.valid[s.first] = DatasetMeters();
    meters2.valid[s.first] = DatasetMeters();
  }

  // best perf so far on valid datasets
  std::unordered_map<std::string, double> validminerrs;
  for (const auto& s : validTagSets) {
    validminerrs[s.first] = DBL_MAX;
  }
  std::unordered_map<std::string, double> validWerWithDecoder;

  /* =========== Create MTLLoss module ==================================== */
  if (!FLAGS_mtllossmapping.empty()) {
    FL_LOG(fl::INFO) << "Building the MTL Loss";
    FL_LOG(fl::INFO) << "Loading " << FLAGS_mtllossmapping;
    mtl_mapping = asr4real::loadMapping(FLAGS_mtllossmapping);
    const int n_categories = mtl_mapping.size();

    mtl_classifier =
        std::make_shared<fl::Linear>(FLAGS_contextdim, n_categories);

    // TODO : update
    mtloptim = initOptimizer(
        {mtl_classifier}, FLAGS_critoptim, FLAGS_lrcrit, 0.0, 0.0);

    FL_LOG(fl::INFO) << "[MTL Classifier] " << mtl_classifier->prettyString();
    FL_LOG(fl::INFO) << "[MTL Optimizer] " << mtloptim->prettyString();
  }
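  // When FLAGS_mtllossmapping is set, an auxiliary multi-task classifier is
  // trained on top of the context output; its loss is added to the main loss
  // with weight FLAGS_mtllossweight inside the training loop below.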

  /* ===================== Logging ===================== */
  std::ofstream logFile;
  if (isMaster) {
    dirCreate(runPath);
    logFile.open(getRunFile("log", runIdx, runPath));
    if (!logFile.is_open()) {
      FL_LOG(fl::FATAL) << "failed to open log file for writing";
    }
    // write config
    std::ofstream configFile(getRunFile("config", runIdx, runPath));
    cereal::JSONOutputArchive ar(configFile);
    ar(CEREAL_NVP(config));
  }

  auto logStatus = [&logFile, isMaster](
                       TrainMeters& mtrs,
                       std::unordered_map<std::string, double>&
                           validWerWithDecoder,
                       int64_t epoch,
                       int64_t nupdates,
                       double lr,
                       double lrcrit,
                       double scaleFactor) {
    syncMeter(mtrs);

    if (isMaster) {
      auto logMsg = getLogString(
          mtrs, validWerWithDecoder, epoch, nupdates, lr, lrcrit, scaleFactor);
      FL_LOG_MASTER(INFO) << logMsg;
      appendToLog(logFile, logMsg);
    }
  };

  auto saveModels = [&](int iter, int totalupdates, bool saveValid) {
    if (isMaster) {
      // Save last epoch
      config[kEpoch] = std::to_string(iter);
      config[kUpdates] = std::to_string(totalupdates);

      std::string filename;
      if (FLAGS_itersave) {
        filename =
            getRunFile(format("model_iter_%03d.bin", iter), runIdx, runPath);
        Serializer::save(
            filename,
            FL_APP_ASR_VERSION,
            config,
            network,
            criterion,
            criterion2,
            netoptim,
            netoptim2,
            critoptim,
            critoptim2);
      }

      // save last model
      filename = getRunFile("model_last.bin", runIdx, runPath);
      Serializer::save(
          filename,
          FL_APP_ASR_VERSION,
          config,
          network,
          criterion,
          criterion2,
          netoptim,
          netoptim2,
          critoptim,
          critoptim2);

      // save if better than ever for one valid (using supervised meters)
      for (const auto& v : validminerrs) {
        double verr;
        verr = meters2.valid[v.first].wrdEdit.errorRate()[0];

        if ((verr > 0.01) && (verr < validminerrs[v.first])) {
          validminerrs[v.first] = verr;
          std::string cleaned_v = cleanFilepath(v.first);
          std::string vfname =
              getRunFile("model_" + cleaned_v + ".bin", runIdx, runPath);
          Serializer::save(
              vfname,
              FL_APP_ASR_VERSION,
              config,
              network,
              criterion,
              criterion2,
              netoptim,
              netoptim2,
              critoptim,
              critoptim2);
        }
      }

      auto* curMemMgr =
          fl::MemoryManagerInstaller::currentlyInstalledMemoryManager();
      if (curMemMgr) {
        curMemMgr->printInfo("Memory Manager Stats", 0 /* device id */);
      }
    }
  };
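  // saveModels (master only) writes:
  //   model_last.bin        - always, and used by the `continue` mode
  //   model_iter_%03d.bin   - on each call when --itersave is set
  //   model_<valid-tag>.bin - whenever the supervised WER on that validation
  //                           set improves (and exceeds 0.01)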

  /* ===================== Create Dataset ===================== */

  int64_t supbatchsize = FLAGS_supbatchsize;
  if (supbatchsize == 0) {
    supbatchsize = FLAGS_batchsize;
  }

  TargetGenerationConfig targetGenConfig(
      FLAGS_wordseparator,
      FLAGS_sampletarget,
      FLAGS_criterion2,
      FLAGS_surround,
      isSeq2seqCrit,
      FLAGS_replabel,
      true /* skip unk */,
      FLAGS_usewordpiece /* fallback2LetterWordSepLeft */,
      !FLAGS_usewordpiece /* fallback2LetterWordSepRight */);

  const auto sfxConf = (FLAGS_sfx_config.empty())
      ? std::vector<sfx::SoundEffectConfig>()
      : sfx::readSoundEffectConfigFile(FLAGS_sfx_config);

  auto inputTransform = inputFeatures(
      featParams,
      featType,
      {FLAGS_localnrmlleftctx, FLAGS_localnrmlrightctx},
      sfxConf);
  auto targetTransform = targetFeatures(tokenDict, lexicon, targetGenConfig);
  auto wordTransform = wordFeatures(wordDict);
  int targetpadVal = isSeq2seqCrit
      ? tokenDict.getIndex(fl::lib::text::kPadToken)
      : kTargetPadValue;
  int wordpadVal = kTargetPadValue;
  auto padVal = std::make_tuple(0, targetpadVal, wordpadVal);

  std::vector<std::string> trainSplits = split(",", FLAGS_train, true);
  auto trainds = createDataset(
      trainSplits,
      FLAGS_datadir,
      FLAGS_batchsize,
      inputTransform,
      targetTransform,
      wordTransform,
      padVal,
      worldRank,
      worldSize,
      false, // allowEmpty
      FLAGS_batching_strategy,
      FLAGS_batching_max_duration);

  std::vector<std::string> trainSplits2 = split(",", FLAGS_train2, true);
  auto trainds2 = createDataset(
      trainSplits2,
      FLAGS_datadir,
      FLAGS_batchsize,
      inputTransform,
      targetTransform,
      wordTransform,
      padVal,
      worldRank,
      worldSize,
      false, // allowEmpty
      FLAGS_batching_strategy,
      FLAGS_batching_max_duration);

  std::map<std::string, std::shared_ptr<fl::Dataset>> validds;
  int64_t validBatchSize =
      FLAGS_validbatchsize == -1 ? FLAGS_batchsize : FLAGS_validbatchsize;
  for (const auto& s : validTagSets) {
    validds[s.first] = createDataset(
        {s.second},
        FLAGS_datadir,
        validBatchSize,
        inputTransform,
        targetTransform,
        wordTransform,
        padVal,
        worldRank,
        worldSize,
        true // allowEmpty
    );
  }
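  // Three kinds of datasets are built: trainds (FLAGS_train) feeds the
  // unsupervised CPC loss, trainds2 (FLAGS_train2) feeds the supervised
  // criterion2, and one validation dataset is created per tag in FLAGS_valid.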

  /* ===================== Hooks ===================== */
  auto evalOutput = [&dicts, &criterion2, &isSeq2seqCrit](
                        const af::array& op,
                        const af::array& target,
                        DatasetMeters& mtr) {
    auto batchsz = op.dims(2);
    for (int b = 0; b < batchsz; ++b) {
      auto tgt = target(af::span, b);
      auto viterbipath =
          afToVector<int>(criterion2->viterbiPath(op(af::span, af::span, b)));
      auto tgtraw = afToVector<int>(tgt);

      // Remove `-1`s appended to the target for batching (if any)
      auto labellen = getTargetSize(tgtraw.data(), tgtraw.size());
      tgtraw.resize(labellen);

      // remap actual, predicted targets for evaluating edit distance error
      if (dicts.find(kTargetIdx) == dicts.end()) {
        FL_LOG(fl::FATAL) << "Dictionary not provided for target: "
                          << kTargetIdx;
      }
      auto tgtDict = dicts.find(kTargetIdx)->second;

      auto ltrPred = tknPrediction2Ltr(
          viterbipath,
          tgtDict,
          FLAGS_criterion2,
          FLAGS_surround,
          isSeq2seqCrit,
          FLAGS_replabel,
          FLAGS_usewordpiece,
          FLAGS_wordseparator);
      auto ltrTgt = tknTarget2Ltr(
          tgtraw,
          tgtDict,
          FLAGS_criterion2,
          FLAGS_surround,
          isSeq2seqCrit,
          FLAGS_replabel,
          FLAGS_usewordpiece,
          FLAGS_wordseparator);

      auto wrdPred = tkn2Wrd(ltrPred, FLAGS_wordseparator);
      auto wrdTgt = tkn2Wrd(ltrTgt, FLAGS_wordseparator);

      mtr.tknEdit.add(ltrPred, ltrTgt);
      mtr.wrdEdit.add(wrdPred, wrdTgt);
    }
  };
  auto cpc_criterion = std::dynamic_pointer_cast<CPCCriterion>(criterion);

  // masking function in unsupervised loss
  auto maskFunction = [&cpc_criterion](const fl::Variable& inp) {
    auto inpMasked =
        cpc_criterion->getMask(inp, FLAGS_maskprob, FLAGS_masklength);
    return inpMasked;
  };

  auto numMaskFunction = [&cpc_criterion]() {
    return cpc_criterion->numMask();
  };

  auto test = [&maskFunction, &evalOutput](
                  std::shared_ptr<fl::Sequential> ntwrk,
                  std::shared_ptr<SequenceCriterion> crit,
                  std::shared_ptr<fl::Dataset> validds,
                  DatasetMeters& mtrs,
                  bool pretrain) {
    ntwrk->eval();
    crit->eval();
    mtrs.tknEdit.reset();
    mtrs.wrdEdit.reset();
    mtrs.loss.reset();
    auto curValidSet = loadPrefetchDataset(
        validds, FLAGS_nthread, false /* shuffle */, 0 /* seed */);
    for (auto& batch : *curValidSet) {
      std::vector<fl::Variable> crit_input;
      int idx = 0;
      auto enc_out = fl::input(batch[kInputIdx]);
      enc_out = ntwrk->module(idx++)->forward({enc_out}).front();
      enc_out = ntwrk->module(idx++)->forward({enc_out}).front();
      fl::Variable enc_out_mask;
      // mask only in unsupervised loss forward pass
      if (pretrain) {
        enc_out_mask = maskFunction(enc_out);
      } else {
        enc_out_mask = enc_out;
      }
      enc_out_mask = ntwrk->module(idx++)->forward({enc_out_mask}).front();
      auto context_mask = w2l::cpc::forwardSequentialModuleWithPadMask(
          enc_out_mask, ntwrk->module(idx++), batch[kDurationIdx]);
      // target is not used in unsupervised loss
      if (pretrain) {
        crit_input = {enc_out, context_mask};
      } else {
        auto output = ntwrk->module(idx)->forward({context_mask}).front();
        crit_input = {output, fl::Variable(batch[kTargetIdx], false)};
        evalOutput(output.array(), batch[kTargetIdx], mtrs);
      }
      auto loss = crit->forward(crit_input).front();
      mtrs.loss.add(loss.array());
    }
  };

  auto lrSched = [](int64_t iter, int64_t totalIter, bool pretrain) {
    int64_t hold, warmup;
    if (!pretrain) {
      hold = FLAGS_suphold;
      warmup = FLAGS_supwarmup;
    } else {
      hold = FLAGS_hold;
      warmup = FLAGS_warmup;
    }
    double lrScale = 1;
    // lr schedulers (in normal operation the unsupervised loss uses
    // warmup + linear decay and the supervised loss uses warmup + constant;
    // ignore "custom")
    if (iter <= warmup) {
      lrScale = ((double)iter) / warmup;
    } else if (FLAGS_lr_sched == "custom") {
      if (pretrain) {
        int64_t offset = 750000;
        int64_t target = 760000;
        if (iter < offset) {
          lrScale = FLAGS_lr_ld_final;
        } else if (iter < target) {
          auto lrTarget = FLAGS_lr_ld_final +
              ((1.0 - FLAGS_lr_ld_final) * (totalIter - target)) /
                  (totalIter - hold);
          lrScale = FLAGS_lr_ld_final +
              ((lrTarget - FLAGS_lr_ld_final) * (iter - offset)) /
                  (target - offset);
        } else {
          lrScale = FLAGS_lr_ld_final +
              ((1.0 - FLAGS_lr_ld_final) * (totalIter - iter)) /
                  (totalIter - hold);
        }
      } else {
      }
    } else if (FLAGS_lr_sched == "inv_sqrt") {
      hold = std::max(warmup, hold);
      if (iter > hold) {
        lrScale = std::sqrt((double)hold) / std::sqrt((double)iter);
      }
    } else if (FLAGS_lr_sched == "linear") {
      hold = std::max(warmup, hold);
      if (iter > hold) {
        lrScale = FLAGS_lr_ld_final +
            ((1.0 - FLAGS_lr_ld_final) * (totalIter - iter)) /
                (totalIter - hold);
      }
    } else if (FLAGS_lr_sched == "step") {
      hold = std::max(warmup + FLAGS_lr_step_decay, hold);
      int64_t power = 0;
      if (iter > hold) {
        power = 1 + (iter - hold) / FLAGS_lr_step_decay;
      }
      lrScale = std::pow(2, -((double)power));
    } else if (FLAGS_lr_sched == "constant") {
    } else {
      throw std::runtime_error("Invalid lr scheduler");
    }
    return lrScale;
  };
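  // A sketch of the "linear" schedule with illustrative values
  // (warmup = 32000, hold = 32000, totalIter = 400000, lr_ld_final = 0.1):
  //   iter <= 32000 : lrScale = iter / 32000                       (warmup)
  //   iter = 216000 : lrScale = 0.1 + 0.9 * (400000 - 216000) / (400000 - 32000)
  //                           = 0.55                               (decay)
  //   iter = 400000 : lrScale = 0.1
  // The "custom" branch hard-codes an extra ramp between updates 750000 and
  // 760000 for the unsupervised loss and otherwise behaves like "constant"
  // for the supervised loss.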

  auto trainEvalIds =
      getTrainEvalIds(trainds2->size(), FLAGS_pcttraineval, FLAGS_seed);

  if (reducer) {
    fl::distributeModuleGrads(network, reducer);
    fl::distributeModuleGrads(criterion, reducer);
    fl::distributeModuleGrads(criterion2, reducer);
  }

  fl::allReduceParameters(network);
  fl::allReduceParameters(criterion);
  fl::allReduceParameters(criterion2);

  auto resetTimeStatMeters = [](TrainMeters& mtrs) {
    mtrs.runtime.reset();
    mtrs.stats.reset();
    mtrs.sampletimer.reset();
    mtrs.fwdtimer.reset();
    mtrs.critfwdtimer.reset();
    mtrs.bwdtimer.reset();
    mtrs.optimtimer.reset();
    mtrs.timer.reset();
  };

  // shuffled datasets for supervised and unsupervised loss
  std::map<int, std::shared_ptr<fl::Dataset>> shuffleds;
  shuffleds[0] = nullptr;
  shuffleds[1] = nullptr;
  // scale counters and factors for each loss
  std::map<int, float> scaleFactors;
  std::map<int, unsigned int> scaleCounters;
  scaleFactors[0] = 0.0f;
  scaleFactors[1] = 0.0f;
  scaleCounters[0] = 1;
  scaleCounters[1] = 1;

  auto train = [&test,
                &logStatus,
                &validWerWithDecoder,
                &saveModels,
                &evalOutput,
                &maskFunction,
                &numMaskFunction,
                &saug,
                &lrSched,
                &validds,
                &trainEvalIds,
                &cpc_criterion,
                &resetTimeStatMeters,
                &startBatch,
                &isMaster,
                &shuffleds,
                &scaleFactors,
                &scaleCounters,
                &mtl_mapping,
                reducer](
                   std::shared_ptr<fl::Sequential> ntwrk,
                   std::shared_ptr<SequenceCriterion> crit,
                   std::shared_ptr<fl::Dataset> trainset,
                   std::shared_ptr<fl::FirstOrderOptimizer> netopt,
                   std::shared_ptr<fl::FirstOrderOptimizer> critopt,
                   std::shared_ptr<fl::Linear> mtlpredictor,
                   std::shared_ptr<fl::FirstOrderOptimizer> mtloptim,
                   TrainMeters& mtrs,
                   double initlr,
                   double initcritlr,
                   bool clampCrit,
                   bool pretrain,
                   int64_t& trainStartBatch,
                   int64_t nbatches) {
    auto runValAndSaveModel = [&](int64_t epoch,
                                  int64_t totalupdates,
                                  double lr,
                                  double lrcrit,
                                  bool saveValid,
                                  float scaleFactor) {
      mtrs.runtime.stop();
      mtrs.timer.stop();
      mtrs.sampletimer.stop();
      mtrs.fwdtimer.stop();
      mtrs.critfwdtimer.stop();
      mtrs.bwdtimer.stop();
      mtrs.optimtimer.stop();

      // valid
      for (auto& vds : validds) {
        test(ntwrk, crit, vds.second, mtrs.valid[vds.first], pretrain);
      }

      // print status
      try {
        logStatus(
            mtrs,
            validWerWithDecoder,
            epoch,
            totalupdates,
            lr,
            lrcrit,
            scaleFactor);
      } catch (const std::exception& ex) {
        FL_LOG(fl::FATAL) << "Error while writing logs: " << ex.what();
      }
      // save last and best models
      try {
        saveModels(epoch, totalupdates, saveValid);
      } catch (const std::exception& ex) {
        FL_LOG(fl::FATAL) << "Error while saving models: " << ex.what();
      }
      // reset meters for next readings
      mtrs.train.loss.reset();
      mtrs.train.tknEdit.reset();
      mtrs.train.wrdEdit.reset();
    };

    // trainIdx = 0 (unsupervised), 1 (supervised)
    int trainIdx = 1 - pretrain;
    // curBatch is number of updates for the current loss being computed
    // from beginning
    int64_t curBatch = trainStartBatch;
    float scaleFactor = scaleFactors[trainIdx];
    unsigned int scaleCounter = scaleCounters[trainIdx];
    if (scaleFactor == 0.0f) {
      scaleFactor =
          FLAGS_fl_amp_use_mixed_precision ? FLAGS_fl_amp_scale_factor : 1;
    }
    unsigned int kScaleFactorUpdateInterval =
        FLAGS_fl_amp_scale_factor_update_interval;
    unsigned int kMaxScaleFactor = FLAGS_fl_amp_max_scale_factor;
    double kMinScaleFactor = 2 * fl::kAmpMinimumScaleFactorValue;
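    // AMP loss scaling: the loss is multiplied by scaleFactor before
    // backward; whenever the loss or any gradient turns NaN/Inf the factor is
    // halved and the same batch is retried (retrySample), and after roughly
    // kScaleFactorUpdateInterval clean updates it is doubled, capped at
    // kMaxScaleFactor. Gradients are divided by the all-reduced batch size
    // times scaleFactor before the optimizer step.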

    while (curBatch < nbatches) {
      ntwrk->train();
      crit->train();

      int64_t freeze = FLAGS_freeze;
      // iter is total number of updates from beginning
      int64_t iter = startBatch + (curBatch - trainStartBatch) + 1;
      int64_t totalIter = FLAGS_iter;
      if (!pretrain) {
        iter -= FLAGS_supdelay;
        totalIter -= FLAGS_supdelay;
      }
      double lrScale = lrSched(iter, totalIter, pretrain);
      netopt->setLr(lrScale * initlr);
      critopt->setLr(lrScale * initcritlr);

      auto datasize = trainset->size();
      // batchIdx is index in batch of current loss
      auto batchIdx = curBatch % datasize;
      // curEpoch is epoch of current loss
      auto curEpoch = 1 + curBatch / datasize;
      // printf("train %d %d %d\n", trainIdx, batchIdx, datasize);

      if ((shuffleds[trainIdx] == nullptr) || (batchIdx == 0)) {
        // testing different ways of shuffling with updated dataset pipeline
        shuffleds[trainIdx] = loadPrefetchDataset(
            trainset, FLAGS_nthread, true, pretrain + curEpoch);
      }

      // auto printInfo = isMaster;
      auto printInfo = curBatch < 100;

      af::sync();
      mtrs.sampletimer.resume();
      mtrs.runtime.resume();
      mtrs.timer.resume();
      const auto& batch = (shuffleds[trainIdx])->get(batchIdx);
      ++curBatch;

      af::sync();
      mtrs.timer.incUnit();
      mtrs.sampletimer.stopAndIncUnit();
      mtrs.stats.add(batch[kInputIdx], batch[kTargetIdx]);
      if (af::anyTrue<bool>(af::isNaN(batch[kInputIdx])) ||
          af::anyTrue<bool>(af::isNaN(batch[kTargetIdx]))) {
        FL_LOG(fl::FATAL) << "Sample has NaN values - "
                          << join(",", readSampleIds(batch[kSampleIdx]));
      }

      bool retrySample = false;
      do {
        retrySample = false;

        // forward
        mtrs.fwdtimer.resume();

        std::vector<fl::Variable> crit_input;
        fl::Variable output;
        fl::Variable l2_enc_out;
        auto enc_out = fl::input(batch[kInputIdx]);
        int idx = 0;
        enc_out = ntwrk->module(idx++)->forward({enc_out}).front();
        auto dtype = enc_out.type();
        l2_enc_out =
            reorder(mean((enc_out * enc_out).as(f32), {0, 1}), 2, 0, 1, 3);
        enc_out = ntwrk->module(idx++)->forward({enc_out}).front().as(dtype);
        fl::Variable enc_out_mask;
        if (pretrain) {
          enc_out_mask = maskFunction(enc_out.as(f32)).as(dtype);
          l2_enc_out = l2_enc_out * numMaskFunction();
        } else if (FLAGS_use_saug && (iter > FLAGS_saug_warmup)) {
          saug->setMaskEmbedding(cpc_criterion->getMaskEmbedding());
          enc_out_mask = saug->forward(enc_out.as(f32)).as(dtype);
        } else {
          enc_out_mask = enc_out;
        }
        enc_out_mask = ntwrk->module(idx++)->forward({enc_out_mask}).front();

        enc_out = fl::dropout(enc_out, FLAGS_dropout_feat);
        enc_out_mask = fl::dropout(enc_out_mask, FLAGS_dropout_feat);
        auto context_mask = w2l::cpc::forwardSequentialModuleWithPadMask(
            enc_out_mask, ntwrk->module(idx++), batch[kDurationIdx]);
        if (pretrain) {
          crit_input = {enc_out, context_mask};
        } else {
          output = ntwrk->module(idx)->forward({context_mask}).front().as(f32);
          crit_input = {output, fl::noGrad(batch[kTargetIdx])};
        }
        af::sync();
        mtrs.critfwdtimer.resume();
        auto loss = crit->forward(crit_input).front();

        if (mtlpredictor) {
          mtlpredictor->train();
          mtloptim->zeroGrad();
          fl::Variable mtl_loss = asr4real::mtl_step(
              context_mask,
              mtlpredictor,
              shuffleds[trainIdx],
              mtl_mapping,
              batchIdx);
          loss = loss + FLAGS_mtllossweight * mtl_loss;
        }
        // add l2 encoder output penalty term in unsupervised loss
        if (pretrain) {
          loss = loss + FLAGS_l2_enc_pen * l2_enc_out;
        }

        if (printInfo) {
          auto str = "loss " + std::to_string(curBatch);
          af::print(str.c_str(), loss.array());
        }

        af::sync();
        mtrs.fwdtimer.stopAndIncUnit();
        mtrs.critfwdtimer.stopAndIncUnit();

        if (FLAGS_fl_amp_use_mixed_precision) {
          ++scaleCounter;
          loss = loss * scaleFactor;
        }

        if (af::anyTrue<bool>(af::isNaN(loss.array())) ||
            af::anyTrue<bool>(af::isInf(loss.array()))) {
          if (FLAGS_fl_amp_use_mixed_precision &&
              scaleFactor >= kMinScaleFactor) {
            scaleFactor = scaleFactor / 2.0f;
            if (isMaster) {
              FL_VLOG(2) << "AMP: Scale factor decreased. New value:\t"
                         << scaleFactor;
            }
            scaleCounter = 1;
            retrySample = true;
            continue;
          } else {
            FL_LOG(fl::FATAL) << "Loss has NaN values. Samples - "
                              << join(",", readSampleIds(batch[kSampleIdx]));
          }
        }

        std::hash<std::string> hasher;
        if (!pretrain &&
            (hasher(join(",", readSampleIds(batch[kSampleIdx]))) % 100 <=
             FLAGS_pcttraineval)) {
          evalOutput(output.array(), batch[kTargetIdx], mtrs.train);
        }

        // backward
        mtrs.bwdtimer.resume();
        netopt->zeroGrad();
        critopt->zeroGrad();
        loss.backward();
        if (reducer) {
          reducer->finalize();
        }
        af::sync();
        mtrs.bwdtimer.stopAndIncUnit();

        // optimizer
        mtrs.optimtimer.resume();

        // scale down gradients by batchsize
        af::array tokenSize = af::constant(loss.dims(0), 1, f32);
        if (reducer) {
          fl::allReduce(tokenSize);
        }
        float tokenSizeScalar = tokenSize.scalar<float>();

        for (const auto& p : ntwrk->module(0)->params()) {
          // the encoder gradient is scaled by FLAGS_grad_mult_feat and is
          // kept only for the unsupervised loss, or for the supervised loss
          // when FLAGS_trainencoder is set
          if (pretrain || FLAGS_trainencoder) {
            p.grad() = p.grad() * FLAGS_grad_mult_feat;
          } else {
            p.grad() = p.grad() * 0;
          }
        }
        if (!pretrain && !FLAGS_trainencoder) {
          for (const auto& p : ntwrk->module(1)->params()) {
            p.grad() = p.grad() * 0;
          }
        }
        // gradients of the projection and context modules are zeroed for the
        // supervised loss when FLAGS_traincontext is false or iter < freeze
        if (!pretrain && (!FLAGS_traincontext || (iter < freeze))) {
          for (const auto& p : ntwrk->module(2)->params()) {
            p.grad() = p.grad() * 0;
          }
          for (const auto& p : ntwrk->module(3)->params()) {
            p.grad() = p.grad() * 0;
          }
        }
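        // Summary of the freezing logic above for the supervised loss:
        //   module 0 (encoder)           - gradient scaled by grad_mult_feat,
        //                                  zeroed unless FLAGS_trainencoder
        //   module 1 (LayerNorm)         - zeroed unless FLAGS_trainencoder
        //   modules 2-3 (proj + context) - zeroed unless FLAGS_traincontext
        //                                  and iter >= FLAGS_freeze
        // During pretraining all modules keep their gradients (the encoder's
        // is still multiplied by FLAGS_grad_mult_feat).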
        int gradIdx = 0;
        for (const auto& p : ntwrk->params()) {
          if (!p.isGradAvailable()) {
            continue;
          }
          p.grad() = p.grad() / (tokenSizeScalar * scaleFactor);
          if (FLAGS_fl_amp_use_mixed_precision) {
            if (af::anyTrue<bool>(af::isNaN(p.grad().array())) ||
                af::anyTrue<bool>(af::isInf(p.grad().array()))) {
              if (scaleFactor >= kMinScaleFactor) {
                scaleFactor = scaleFactor / 2.0f;
                FL_VLOG(2) << "AMP: Scale factor decreased. New value:\t"
                           << "gradidx " << gradIdx << "\t"
                           << "grad dims " << p.grad().dims() << "\t"
                           << scaleFactor;
                retrySample = true;
                scaleCounter = 1;
                break;
              } else {
                FL_LOG(fl::FATAL)
                    << "Gradient Loss has NaN values. Samples - "
                    << join(",", readSampleIds(batch[kSampleIdx]));
              }
            }
          }
          gradIdx++;
        }

        if (retrySample) {
          mtrs.optimtimer.stop();
          continue;
        }

        mtrs.train.loss.add((loss / scaleFactor).array());

        for (const auto& p : crit->params()) {
          if (!p.isGradAvailable()) {
            continue;
          }
          p.grad() = p.grad() / (tokenSizeScalar * scaleFactor);
        }

      } while (retrySample);

      // debugging code
      // logStatus(mtrs, curEpoch, iter, netopt->getLr(), critopt->getLr());
      // if (curBatch == 10) {
      //   resetTimeStatMeters(mtrs);
      //}

      // clamp gradients
      double maxgradnorm = FLAGS_maxgradnorm;
      if (!pretrain && FLAGS_maxgradnorm2 > 0.0) {
        maxgradnorm = FLAGS_maxgradnorm2;
      }
      if (maxgradnorm > 0) {
        auto params = ntwrk->params();
        if (clampCrit) {
          auto critparams = crit->params();
          params.insert(params.end(), critparams.begin(), critparams.end());
        }
        auto gradnorm = fl::clipGradNorm(params, maxgradnorm);
        if (printInfo) {
          std::cout << "gradnorm " << curBatch << ": " << gradnorm << std::endl;
        }
      }

      // update weights
      if (lrScale > 0) {
        critopt->step();
        netopt->step();
      }
      if (lrScale > 0 && mtlpredictor) {
        mtloptim->step();
      }
      af::sync();
      mtrs.optimtimer.stopAndIncUnit();

      if (FLAGS_fl_amp_use_mixed_precision) {
        if (printInfo) {
          std::cout << "scale factor " << curBatch << ": " << scaleFactor
                    << std::endl;
        }
        if (scaleFactor < kMaxScaleFactor) {
          if (scaleCounter % kScaleFactorUpdateInterval == 0) {
            scaleFactor *= 2;
            if (isMaster) {
              FL_VLOG(2) << "AMP: Scale factor increased. New value:\t"
                         << scaleFactor;
            }
          } else {
            // scaleFactor += 2;
          }
        }
      }

      // mtrs.sampletimer.resume();
      mtrs.runtime.stop();
      mtrs.timer.stop();

      if (FLAGS_reportiters > 0 && curBatch % FLAGS_reportiters == 0) {
        runValAndSaveModel(
            curEpoch,
            curBatch,
            netopt->getLr(),
            critopt->getLr(),
            pretrain,
            scaleFactor);
        resetTimeStatMeters(mtrs);
      }

      if ((batchIdx == (datasize - 1)) ||
          (pretrain && (iter == FLAGS_supdelay)) || (iter == totalIter)) {
        runValAndSaveModel(
            curEpoch,
            iter,
            netopt->getLr(),
            critopt->getLr(),
            pretrain,
            scaleFactor);
        resetTimeStatMeters(mtrs);
      }
      af::sync();
    }
    trainStartBatch = curBatch;
    scaleFactors[trainIdx] = scaleFactor;
    scaleCounters[trainIdx] = scaleCounter;
  };

  std::cout << " *** >>> NEW CALL TO TRAIN" << std::endl;

  // loading from a previous checkpoint
  if (startBatch < FLAGS_supdelay) {
    unsupStartBatch = startBatch;
  } else if (FLAGS_twostage) {
    unsupStartBatch = FLAGS_supdelay;
    supStartBatch = (startBatch - FLAGS_supdelay);
  } else {
    unsupStartBatch = FLAGS_supdelay +
        (startBatch - FLAGS_supdelay) * FLAGS_unsupdates /
            (FLAGS_unsupdates + FLAGS_supdates);
    supStartBatch = (startBatch - FLAGS_supdelay) * FLAGS_supdates /
        (FLAGS_unsupdates + FLAGS_supdates);
  }
  // supStartBatch is number of updates of supervised loss
  // unsupStartBatch is number of updates of unsupervised loss
  startBatch = unsupStartBatch + supStartBatch;
  printf("unsup: %ld, sup: %ld\n", unsupStartBatch, supStartBatch);

  resetTimeStatMeters(meters);
  resetTimeStatMeters(meters2);

  // alternately iterate between unsupervised and supervised loss
  while (startBatch < FLAGS_iter) {
    // unsupervised loss updates for FLAGS_unsupdates iterations
    // if FLAGS_twostage is true, first do only unsupervised updates and then
    // only supervised updates
    // if FLAGS_twostage is false, always do unsupervised updates and add
    // supervised updates only after FLAGS_supdelay iterations
    if (!FLAGS_twostage || startBatch < FLAGS_supdelay) {
      train(
          network,
          criterion,
          trainds,
          netoptim,
          critoptim,
          mtl_classifier,
          mtloptim,
          meters,
          FLAGS_lr,
          FLAGS_lrcrit,
          true /* clampCrit */,
          true,
          unsupStartBatch,
          unsupStartBatch + FLAGS_unsupdates);
      startBatch = unsupStartBatch + supStartBatch;
      if (FLAGS_twostage && (startBatch == FLAGS_supdelay)) {
        break;
      }
    }
    // supervised loss updates for FLAGS_supdates iterations
    if (startBatch >= FLAGS_supdelay) {
      train(
          network,
          criterion2,
          trainds2,
          netoptim2,
          critoptim2,
          mtl_classifier,
          mtloptim,
          meters2,
          FLAGS_lr2,
          FLAGS_lrcrit2,
          true /* clampCrit */,
          false,
          supStartBatch,
          supStartBatch + FLAGS_supdates);
      startBatch = unsupStartBatch + supStartBatch;
    }
  }

  FL_LOG_MASTER(INFO) << "Finished training";
  return 0;
}