in recipes/slimIPL/src/Train.cpp [106:1934]
int main(int argc, char** argv) {
fl::init();
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
std::string exec(argv[0]);
std::vector<std::string> argvs;
for (int i = 0; i < argc; i++) {
argvs.emplace_back(argv[i]);
}
gflags::SetUsageMessage(
"Usage: \n " + exec + " train [flags]\n or " + exec +
" continue [directory] [flags]\n or " + exec +
" fork [directory/model] [flags]");
/* ===================== Parse Options ===================== */
int runIdx = 1; // current #runs in this path
std::string runPath; // current experiment path
std::string reloadPath; // path to model to reload
std::string runStatus = argv[1];
int64_t startEpoch = 0;
int64_t startUpdate = 0;
if (argc <= 1) {
LOG(FATAL) << gflags::ProgramUsage();
}
if (runStatus == kTrainMode) {
parseCmdLineFlagsWrapper(argc, argv);
runPath = FLAGS_rundir;
} else if (runStatus == kContinueMode) {
runPath = argv[2];
while (fileExists(getRunFile("model_last.bin", runIdx, runPath))) {
++runIdx;
}
reloadPath = getRunFile("model_last.bin", runIdx - 1, runPath);
LOG(INFO) << "reload path is " << reloadPath;
std::unordered_map<std::string, std::string> cfg;
std::string version;
Serializer::load(reloadPath, version, cfg);
auto flags = cfg.find(kGflags);
if (flags == cfg.end()) {
LOG(FATAL) << "Invalid config loaded from " << reloadPath;
}
LOG(INFO) << "Reading flags from config file " << reloadPath;
gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);
parseCmdLineFlagsWrapper(argc, argv);
auto epoch = cfg.find(kEpoch);
if (epoch == cfg.end()) {
LOG(WARNING) << "Did not find epoch to start from, starting from 0.";
} else {
startEpoch = std::stoi(epoch->second);
}
auto nbupdates = cfg.find(kUpdates);
if (nbupdates == cfg.end()) {
LOG(WARNING) << "Did not find #updates to start from, starting from 0.";
} else {
startUpdate = std::stoi(nbupdates->second);
}
} else if (runStatus == kForkMode) {
reloadPath = argv[2];
std::unordered_map<std::string, std::string> cfg;
std::string version;
Serializer::load(reloadPath, version, cfg);
auto flags = cfg.find(kGflags);
if (flags == cfg.end()) {
LOG(FATAL) << "Invalid config loaded from " << reloadPath;
}
LOG(INFO) << "Reading flags from config file " << reloadPath;
gflags::ReadFlagsFromString(flags->second, gflags::GetArgv0(), true);
parseCmdLineFlagsWrapper(argc, argv);
runPath = FLAGS_rundir;
} else {
LOG(FATAL) << gflags::ProgramUsage();
}
if (runPath.empty()) {
LOG(FATAL) << "'runpath' specified by --rundir, --runname cannot be empty";
}
af::setSeed(FLAGS_seed);
fl::DynamicBenchmark::setBenchmarkMode(FLAGS_fl_benchmark_mode);
std::shared_ptr<fl::Reducer> reducer = nullptr;
if (FLAGS_enable_distributed) {
fl::pkg::runtime::initDistributed(
FLAGS_world_rank,
FLAGS_world_size,
FLAGS_max_devices_per_node,
FLAGS_rndv_filepath);
reducer = std::make_shared<fl::CoalescingReducer>(1.0, true, true);
}
int worldRank = fl::getWorldRank();
int worldSize = fl::getWorldSize();
bool isMaster = (worldRank == 0);
FL_LOG_MASTER(INFO) << "Gflags after parsing \n" << serializeGflags("; ");
FL_LOG_MASTER(INFO) << "Experiment path: " << runPath;
FL_LOG_MASTER(INFO) << "Experiment runidx: " << runIdx;
// flashlight optim mode
auto flOptimLevel = FLAGS_fl_optim_mode.empty()
? fl::OptimLevel::DEFAULT
: fl::OptimMode::toOptimLevel(FLAGS_fl_optim_mode);
fl::OptimMode::get().setOptimLevel(flOptimLevel);
if (FLAGS_fl_amp_use_mixed_precision) {
// Only set the optim mode to O1 if it was left empty
LOG(INFO) << "Mixed precision training enabled. Will perform loss scaling.";
if (FLAGS_fl_optim_mode.empty()) {
LOG(INFO) << "Mixed precision training enabled with no "
"optim mode specified - setting optim mode to O1.";
fl::OptimMode::get().setOptimLevel(fl::OptimLevel::O1);
}
}
// Run metadata. This map is written to the JSON "config" file of the run and
// is passed to Serializer::save together with every model checkpoint.
std::unordered_map<std::string, std::string> config = {
{kProgramName, exec},
{kCommandLine, join(" ", argvs)},
{kGflags, serializeGflags()},
// extra goodies
{kUserName, fl::lib::getEnvVar("USER")},
{kHostName, fl::lib::getEnvVar("HOSTNAME")},
// NOTE(review): getCurrentDate() is concatenated with itself, so the value is
// "<date>, <date>". The second call was presumably meant to be a time-of-day
// helper — confirm which helper is available and fix.
{kTimestamp, getCurrentDate() + ", " + getCurrentDate()},
{kRunIdx, std::to_string(runIdx)},
{kRunPath, runPath}};
std::vector<std::pair<std::string, std::string>> validTagSets =
parseValidSets(FLAGS_valid);
/* ===================== Create Dictionary & Lexicon ===================== */
auto dictPath = FLAGS_tokens;
if (dictPath.empty() || !fileExists(dictPath)) {
throw std::runtime_error(
"Invalid dictionary filepath specified with "
" --tokens: \"" +
dictPath + "\"");
}
fl::lib::text::Dictionary tokenDict(dictPath);
// Setup-specific modifications
for (int64_t r = 1; r <= FLAGS_replabel; ++r) {
tokenDict.addEntry("<" + std::to_string(r) + ">");
}
// ctc expects the blank label last
if (FLAGS_criterion == kCtcCriterion) {
tokenDict.addEntry(kBlankToken);
}
bool isSeq2seqCrit = FLAGS_criterion == kSeq2SeqTransformerCriterion ||
FLAGS_criterion == kSeq2SeqRNNCriterion;
if (isSeq2seqCrit) {
tokenDict.addEntry(fl::pkg::speech::kEosToken);
tokenDict.addEntry(fl::lib::text::kPadToken);
}
int numClasses = tokenDict.indexSize();
LOG(INFO) << "Number of classes (network): " << numClasses;
fl::lib::text::Dictionary wordDict;
fl::lib::text::LexiconMap lexicon;
if (!FLAGS_lexicon.empty()) {
lexicon = fl::lib::text::loadWords(FLAGS_lexicon, FLAGS_maxword);
wordDict = fl::lib::text::createWordDict(lexicon);
LOG(INFO) << "Number of words: " << wordDict.indexSize();
}
/* ===================== Create Dataset ===================== */
std::unordered_map<std::string, std::string> plCache;
std::unordered_map<std::string, std::string> plCacheDump;
std::unordered_map<std::string, af::array> plCacheSoft;
std::unordered_map<std::string, af::array> plCacheDumpSoft;
std::vector<int> plBatchCacheFixedSize;
fl::lib::audio::FeatureParams featParams(
FLAGS_samplerate,
FLAGS_framesizems,
FLAGS_framestridems,
FLAGS_filterbanks,
FLAGS_lowfreqfilterbank,
FLAGS_highfreqfilterbank,
FLAGS_mfcccoeffs,
kLifterParam /* lifterparam */,
FLAGS_devwin /* delta window */,
FLAGS_devwin /* delta-delta window */);
featParams.useEnergy = false;
featParams.usePower = false;
featParams.zeroMeanFrame = false;
auto featureRes =
getFeatureType(FLAGS_features_type, FLAGS_channels, featParams);
int numFeatures = featureRes.first;
FeatureType featType = featureRes.second;
// Options controlling how transcriptions are turned into token targets; the
// config is handed to targetFeatures() below.
TargetGenerationConfig targetGenConfig(
FLAGS_wordseparator,
FLAGS_sampletarget,
FLAGS_criterion,
FLAGS_surround,
isSeq2seqCrit,
FLAGS_replabel,
true /* skip unk */,
FLAGS_usewordpiece /* fallback2LetterWordSepLeft */,
!FLAGS_usewordpiece /* presumably fallback2LetterWordSepRight (this label previously repeated "...Left") - TODO confirm against TargetGenerationConfig */);
const auto sfxConf = (FLAGS_sfx_config.empty())
? std::vector<sfx::SoundEffectConfig>()
: sfx::readSoundEffectConfigFile(FLAGS_sfx_config);
auto inputTransform = inputFeatures(
featParams,
featType,
{FLAGS_localnrmlleftctx, FLAGS_localnrmlrightctx},
sfxConf);
auto targetTransform = targetFeatures(tokenDict, lexicon, targetGenConfig);
auto wordTransform = wordFeatures(wordDict);
int targetpadVal = isSeq2seqCrit
? tokenDict.getIndex(fl::lib::text::kPadToken)
: kTargetPadValue;
int wordpadVal = kTargetPadValue;
auto padVal = std::make_tuple(0, targetpadVal, wordpadVal);
std::vector<std::string> trainSplits = fl::lib::split(",", FLAGS_train, true);
std::vector<std::string> unsupTrainSplits =
fl::lib::split(",", FLAGS_unsup_train, true);
auto trainds = createDataset(
trainSplits,
FLAGS_datadir,
FLAGS_batchsize,
inputTransform,
targetTransform,
wordTransform,
padVal,
worldRank,
worldSize,
false, // allowEmpty
FLAGS_batching_strategy,
FLAGS_batching_max_duration);
std::shared_ptr<fl::Dataset> unsupTrainds;
if (FLAGS_unsup_train != "") {
unsupTrainds = createDataset(
unsupTrainSplits,
FLAGS_unsup_datadir,
FLAGS_batchsize,
inputTransform,
targetTransform,
wordTransform,
padVal,
worldRank,
worldSize,
false, // allowEmpty
FLAGS_batching_strategy,
FLAGS_batching_max_duration);
}
LOG(INFO) << "Sup batches " << trainds->size();
if (unsupTrainds != nullptr) {
LOG(INFO) << "Unsup batches " << unsupTrainds->size();
}
std::map<std::string, std::shared_ptr<fl::Dataset>> validds;
int64_t validBatchSize =
FLAGS_validbatchsize == -1 ? FLAGS_batchsize : FLAGS_validbatchsize;
for (const auto& s : validTagSets) {
validds[s.first] = createDataset(
{s.second},
FLAGS_datadir,
validBatchSize,
inputTransform,
targetTransform,
wordTransform,
padVal,
worldRank,
worldSize,
true // allowEmpty
);
}
/* =========== Create Network & Optimizers / Reload Snapshot ============ */
std::shared_ptr<fl::Module> network;
std::shared_ptr<fl::Module> networkEMA;
std::shared_ptr<SequenceCriterion> criterion;
std::shared_ptr<fl::FirstOrderOptimizer> netoptim;
std::shared_ptr<fl::FirstOrderOptimizer> critoptim;
std::shared_ptr<fl::lib::text::LM> lm;
std::shared_ptr<WordDecodeMaster> dm;
auto scalemode = getCriterionScaleMode(FLAGS_onorm, FLAGS_sqnorm);
(void)fl::pkg::runtime::ModulePlugin(FLAGS_arch);
if (runStatus == kTrainMode) {
FL_LOG_MASTER(INFO) << "Loading architecture file from " << FLAGS_arch;
// Encoder network, works on audio
network = fl::pkg::runtime::ModulePlugin(FLAGS_arch)
.arch(numFeatures, numClasses);
networkEMA = network;
if (FLAGS_slimIPL_ema) {
networkEMA = fl::pkg::runtime::ModulePlugin(FLAGS_arch)
.arch(numFeatures, numClasses);
for (size_t i = 0; i < networkEMA->params().size(); ++i) {
auto param = network->param(i).array();
param.eval();
networkEMA->setParams(fl::Variable(param, false), i);
}
}
if (FLAGS_criterion == kCtcCriterion) {
criterion = std::make_shared<CTCLoss>(scalemode);
} else if (FLAGS_criterion == kAsgCriterion) {
criterion =
std::make_shared<ASGLoss>(numClasses, scalemode, FLAGS_transdiag);
} else if (FLAGS_criterion == kSeq2SeqRNNCriterion) {
std::vector<std::shared_ptr<AttentionBase>> attentions;
for (int i = 0; i < FLAGS_decoderattnround; i++) {
attentions.push_back(createAttention());
}
criterion = std::make_shared<Seq2SeqCriterion>(
numClasses,
FLAGS_encoderdim,
tokenDict.getIndex(fl::pkg::speech::kEosToken),
tokenDict.getIndex(fl::lib::text::kPadToken),
FLAGS_maxdecoderoutputlen,
attentions,
createAttentionWindow(),
FLAGS_trainWithWindow,
FLAGS_pctteacherforcing,
FLAGS_labelsmooth,
FLAGS_inputfeeding,
FLAGS_samplingstrategy,
FLAGS_gumbeltemperature,
FLAGS_decoderrnnlayer,
FLAGS_decoderattnround,
FLAGS_decoderdropout);
} else if (FLAGS_criterion == kSeq2SeqTransformerCriterion) {
criterion = std::make_shared<TransformerCriterion>(
numClasses,
FLAGS_encoderdim,
tokenDict.getIndex(fl::pkg::speech::kEosToken),
tokenDict.getIndex(fl::lib::text::kPadToken),
FLAGS_maxdecoderoutputlen,
FLAGS_am_decoder_tr_layers,
createAttention(),
createAttentionWindow(),
FLAGS_trainWithWindow,
FLAGS_labelsmooth,
FLAGS_pctteacherforcing,
FLAGS_am_decoder_tr_dropout,
FLAGS_am_decoder_tr_layerdrop);
} else {
LOG(FATAL) << "unimplemented criterion";
}
} else if (runStatus == kForkMode) {
std::unordered_map<std::string, std::string> cfg; // unused
std::string version;
Serializer::load(reloadPath, version, cfg, network, criterion);
if (version != FL_APP_ASR_VERSION) {
LOG(WARNING) << "Model version " << version << " and code version "
<< FL_APP_ASR_VERSION;
// TODO fix EMA
}
} else { // kContinueMode
std::unordered_map<std::string, std::string> cfg; // unused
std::string version;
Serializer::load(
reloadPath, version, cfg, network, criterion, netoptim, critoptim);
if (version != FL_APP_ASR_VERSION) {
LOG(WARNING) << "Model version " << version << " and code version "
<< FL_APP_ASR_VERSION;
}
networkEMA = network;
if (FLAGS_slimIPL_ema) {
Serializer::load(
getRunFile("model_last_ema.bin", runIdx - 1, runPath),
version,
networkEMA);
}
LOG(INFO) << "Loaded model for continue training";
// Restore the pseudo-label (PL) cache files written by the previous run
// (index runIdx - 1) so that training can continue with them.
if (FLAGS_slimIPL_type == "cache" || FLAGS_slimIPL_type == "pre-cache" ||
    FLAGS_slimIPL_type == "fixed-pre-cache") {
  LOG(INFO) << "Reading PL cache";
  if (FLAGS_slimIPL_use_soft) {
    // Soft labels: one serialized map per process; only this rank's is read.
    std::string cacheVersion;
    Serializer::load(
        getRunFile(
            "model_last_cache_soft" + std::to_string(worldRank),
            runIdx - 1,
            runPath),
        cacheVersion,
        plCacheDumpSoft);
  } else {
    // Text cache: one "<key>|<value>" line per sample, one file per process.
    // Every process reads the files of all processes.
    for (int processIdx = 0; processIdx < worldSize; processIdx++) {
      auto cacheName = getRunFile("model_last_cache", runIdx - 1, runPath) +
          std::to_string(processIdx);
      if (!fileExists(cacheName)) {
        LOG(INFO) << "Read cache from " << cacheName
                  << "; Skip, file doesn't exist";
        continue;
      }
      std::ifstream fcache(cacheName);
      std::string line;
      int count = 0;
      while (getline(fcache, line)) {
        auto tmp = fl::lib::split("|", line);
        if (tmp.size() == 0) {
          continue;
        } else if (tmp.size() == 1) {
          // Key with nothing after the separator: empty cached value.
          plCacheDump[tmp[0]] = "";
        } else {
          plCacheDump[tmp[0]] = tmp[1];
        }
        count++;
      }
      fcache.close();
      LOG(INFO) << "Read cache from " << cacheName
                << " with number of samples " << count;
    }
    LOG(INFO) << "Reading PL cache is done; total size "
              << plCacheDump.size();
  }
  if (FLAGS_slimIPL_type == "fixed-pre-cache") {
    auto cacheName =
        getRunFile("model_last_fixed_cache", runIdx - 1, runPath) +
        std::to_string(worldRank);
    if (!fileExists(cacheName)) {
      LOG(INFO) << "Read fixed cache from " << cacheName
                << "; Skip, file doesn't exist";
    } else {
      std::ifstream fcacheFixed(cacheName);
      // The file holds space-separated indices with a trailing space after the
      // last one. Loop on the extraction itself rather than on peek() != EOF:
      // otherwise the trailing space triggers one more (failed) read and a
      // spurious 0 would be appended to the cache.
      int val = 0;
      while (plBatchCacheFixedSize.size() <
                 FLAGS_slimIPL_fixed_cache_updates &&
             (fcacheFixed >> val)) {
        if (val < 0) {
          LOG(FATAL) << "Wrong cache file, negative indices!!! " << val;
        }
        plBatchCacheFixedSize.push_back(val);
      }
      LOG(INFO) << "Reading PL fixed cache is done; total size "
                << plBatchCacheFixedSize.size();
      fcacheFixed.close();
    }
  }
}
}
FL_LOG_MASTER(INFO) << "[Network] " << network->prettyString();
FL_LOG_MASTER(INFO) << "[Network Params: " << numTotalParams(network) << "]";
FL_LOG_MASTER(INFO) << "[Criterion] " << criterion->prettyString();
if (!FLAGS_lm.empty()) {
FL_LOG_MASTER(INFO) << "[Beam-search Decoder] Constructing language model "
"and beam search decoder";
std::vector<float> dummyTransition;
if (FLAGS_decodertype == "wrd" && FLAGS_lmtype == "kenlm" &&
FLAGS_criterion == "ctc") {
lm = std::make_shared<fl::lib::text::KenLM>(FLAGS_lm, wordDict);
dm = std::make_shared<WordDecodeMaster>(
network,
lm,
dummyTransition,
true, // usePlugin
tokenDict,
wordDict,
DecodeMasterTrainOptions{
.repLabel = int32_t(FLAGS_replabel),
.wordSepIsPartOfToken = FLAGS_usewordpiece,
.surround = FLAGS_surround,
.wordSep = FLAGS_wordseparator,
.targetPadIdx = targetpadVal});
} else {
throw std::runtime_error(
"Other decoders are not supported yet during training");
}
}
if (runStatus == kTrainMode || runStatus == kForkMode) {
netoptim = initOptimizer(
{network}, FLAGS_netoptim, FLAGS_lr, FLAGS_momentum, FLAGS_weightdecay);
critoptim =
initOptimizer({criterion}, FLAGS_critoptim, FLAGS_lrcrit, 0.0, 0.0);
}
FL_LOG_MASTER(INFO) << "[Network Optimizer] " << netoptim->prettyString();
FL_LOG_MASTER(INFO) << "[Criterion Optimizer] " << critoptim->prettyString();
double initLinNetlr = FLAGS_linlr >= 0.0 ? FLAGS_linlr : FLAGS_lr;
double initLinCritlr =
FLAGS_linlrcrit >= 0.0 ? FLAGS_linlrcrit : FLAGS_lrcrit;
std::shared_ptr<LinSegCriterion> linseg;
std::shared_ptr<fl::FirstOrderOptimizer> linNetoptim;
std::shared_ptr<fl::FirstOrderOptimizer> linCritoptim;
if (FLAGS_linseg > startUpdate) {
if (FLAGS_criterion != kAsgCriterion) {
LOG(FATAL) << "linseg may only be used with ASG criterion";
}
linseg = std::make_shared<LinSegCriterion>(numClasses, scalemode);
linseg->setParams(criterion->param(0), 0);
FL_LOG_MASTER(INFO) << "[Criterion] " << linseg->prettyString()
<< " (for first " << FLAGS_linseg - startUpdate
<< " updates)";
linNetoptim = initOptimizer(
{network},
FLAGS_netoptim,
initLinNetlr,
FLAGS_momentum,
FLAGS_weightdecay);
linCritoptim =
initOptimizer({linseg}, FLAGS_critoptim, initLinCritlr, 0.0, 0.0);
FL_LOG_MASTER(INFO) << "[Network Optimizer] " << linNetoptim->prettyString()
<< " (for first " << FLAGS_linseg - startUpdate
<< " updates)";
FL_LOG_MASTER(INFO) << "[Criterion Optimizer] "
<< linCritoptim->prettyString() << " (for first "
<< FLAGS_linseg - startUpdate << " updates)";
}
/* ===================== Meters ===================== */
slimIPL::TrainMetersMy meters;
for (const auto& s : validTagSets) {
meters.valid[s.first] = DatasetMeters();
}
// best perf so far on valid datasets
std::unordered_map<std::string, double> validminerrs;
for (const auto& s : validTagSets) {
validminerrs[s.first] = DBL_MAX;
}
std::unordered_map<std::string, double> validMinWerWithDecoder;
std::unordered_map<std::string, double> validWerWithDecoder;
if (dm) {
for (const auto& s : validTagSets) {
validMinWerWithDecoder[s.first] = DBL_MAX;
validWerWithDecoder[s.first] = DBL_MAX;
}
}
/* ===================== Logging ===================== */
std::ofstream logFile;
if (isMaster) {
fl::lib::dirCreate(runPath);
logFile.open(getRunFile("log", runIdx, runPath));
if (!logFile) {
LOG(FATAL) << "failed to open log file for writing";
}
// write config
std::ofstream configFile(getRunFile("config", runIdx, runPath));
cereal::JSONOutputArchive ar(configFile);
ar(CEREAL_NVP(config));
}
/* ===================== PL Generator ===================== */
// Converts a sequence of token indices into words. `isPrediction` selects the
// prediction conversion (tknPrediction2Ltr) over the target conversion
// (tknTarget2Ltr); the resulting letters are then grouped with tkn2Wrd.
auto tokenToWord = [&isSeq2seqCrit, &tokenDict](
                       const std::vector<int>& tokens,
                       bool isPrediction) -> std::vector<std::string> {
  std::vector<std::string> letters = isPrediction
      ? tknPrediction2Ltr(
            tokens,
            tokenDict,
            FLAGS_criterion,
            FLAGS_surround,
            isSeq2seqCrit,
            FLAGS_replabel,
            FLAGS_usewordpiece,
            FLAGS_wordseparator)
      : tknTarget2Ltr(
            tokens,
            tokenDict,
            FLAGS_criterion,
            FLAGS_surround,
            isSeq2seqCrit,
            FLAGS_replabel,
            FLAGS_usewordpiece,
            FLAGS_wordseparator);
  return tkn2Wrd(letters, FLAGS_wordseparator);
};
/* ===================== Hooks ===================== */
// Synchronizes the meters via slimIPL::syncMeter on every process; the master
// process then formats one status line, logs it and appends it to the run's
// log file.
auto logStatus =
    [&logFile, &validTagSets, isMaster](
        slimIPL::TrainMetersMy& mtrs,
        std::unordered_map<std::string, double>& validWerWithDecoder,
        int64_t epoch,
        int64_t nupdates,
        double lr,
        double lrcrit) {
      slimIPL::syncMeter(mtrs);
      if (!isMaster) {
        return;
      }
      auto statusLine = getLogString(
          mtrs, validWerWithDecoder, epoch, nupdates, lr, lrcrit);
      FL_LOG_MASTER(INFO) << statusLine;
      appendToLog(logFile, statusLine);
    };
std::ofstream memLog;
if (FLAGS_fl_log_mem_ops_interval > 0 && isMaster) {
auto* curMemMgr =
fl::MemoryManagerInstaller::currentlyInstalledMemoryManager();
if (curMemMgr) {
memLog.open(getRunFile("mem", runIdx, runPath));
if (!memLog) {
LOG(FATAL) << "failed to open memory log file="
<< getRunFile("mem", runIdx, runPath) << " for writing";
}
curMemMgr->setLogStream(&memLog);
curMemMgr->setLoggingEnabled(true);
curMemMgr->setLogFlushInterval(FLAGS_fl_log_mem_ops_interval);
}
}
auto saveModels = [&](int iter, int totalUpdates) {
if (FLAGS_slimIPL_type == "pre-cache" || FLAGS_slimIPL_type == "cache" ||
FLAGS_slimIPL_type == "fixed-pre-cache") {
std::ofstream fcache, fcacheFixed;
auto cacheName = getRunFile("model_last_cache", runIdx, runPath) +
std::to_string(worldRank);
if (FLAGS_slimIPL_use_soft) {
Serializer::save(
getRunFile(
"model_last_cache_soft" + std::to_string(worldRank),
runIdx,
runPath),
FL_APP_ASR_VERSION,
plCacheSoft);
} else {
fcache.open(cacheName);
for (auto const& element : plCache) {
fcache << element.first << "|" << element.second << std::endl;
}
fcache.close();
cacheName = getRunFile("model_last_fixed_cache", runIdx, runPath) +
std::to_string(worldRank);
fcacheFixed.open(cacheName);
for (auto const& element : plBatchCacheFixedSize) {
fcacheFixed << element << " ";
}
fcacheFixed.close();
}
}
if (isMaster) {
// Save last epoch
config[kEpoch] = std::to_string(iter);
config[kUpdates] = std::to_string(totalUpdates);
std::string filename;
if (FLAGS_itersave) {
filename =
getRunFile(format("model_iter_%03d.bin", iter), runIdx, runPath);
Serializer::save(
filename,
FL_APP_ASR_VERSION,
config,
network,
criterion,
netoptim,
critoptim);
}
// save last model
filename = getRunFile("model_last.bin", runIdx, runPath);
Serializer::save(
filename,
FL_APP_ASR_VERSION,
config,
network,
criterion,
netoptim,
critoptim);
if (FLAGS_slimIPL_ema) {
Serializer::save(
getRunFile("model_last_ema.bin", runIdx, runPath),
FL_APP_ASR_VERSION,
networkEMA);
}
// save if better than ever for one valid
for (const auto& v : validminerrs) {
double verr = meters.valid[v.first].wrdEdit.errorRate()[0];
if (verr < validminerrs[v.first]) {
validminerrs[v.first] = verr;
std::string cleaned_v = cleanFilepath(v.first);
std::string vfname =
getRunFile("model_" + cleaned_v + ".bin", runIdx, runPath);
Serializer::save(
vfname,
FL_APP_ASR_VERSION,
config,
network,
criterion,
netoptim,
critoptim);
}
}
// save if better than ever for one valid with lm decoding
for (const auto& v : validMinWerWithDecoder) {
double verr = validWerWithDecoder[v.first];
if (verr < validMinWerWithDecoder[v.first]) {
validMinWerWithDecoder[v.first] = verr;
std::string cleaned_v = cleanFilepath(v.first);
std::string vfname = getRunFile(
"model_" + cleaned_v + "_decoder.bin", runIdx, runPath);
Serializer::save(
vfname,
FL_APP_ASR_VERSION,
config,
network,
criterion,
netoptim,
critoptim);
}
}
// print brief stats on memory allocation (so far)
auto* curMemMgr =
fl::MemoryManagerInstaller::currentlyInstalledMemoryManager();
if (curMemMgr) {
curMemMgr->printInfo("Memory Manager Stats", 0 /* device id */);
}
}
};
// For each sample along dimension 2 of `op`: takes the criterion's Viterbi
// path as the prediction, strips the batching padding from the target, maps
// both to letters and words, and accumulates token- and word-level edit
// distances into `mtr`.
auto evalOutput = [&tokenDict, &criterion, &isSeq2seqCrit](
                      const af::array& op,
                      const af::array& target,
                      const af::array& inputSizes,
                      DatasetMeters& mtr) {
  const auto numSamples = op.dims(2);
  for (int b = 0; b < numSamples; ++b) {
    auto sampleTarget = target(af::span, b);
    auto predictedTokens = afToVector<int>(
        criterion->viterbiPath(op(af::span, af::span, b), inputSizes.col(b)));
    auto targetTokens = afToVector<int>(sampleTarget);
    // Remove `-1`s appended to the target for batching (if any)
    targetTokens.resize(
        getTargetSize(targetTokens.data(), targetTokens.size()));
    // remap actual, predicted targets for evaluating edit distance error
    auto predictedLetters = tknPrediction2Ltr(
        predictedTokens,
        tokenDict,
        FLAGS_criterion,
        FLAGS_surround,
        isSeq2seqCrit,
        FLAGS_replabel,
        FLAGS_usewordpiece,
        FLAGS_wordseparator);
    auto targetLetters = tknTarget2Ltr(
        targetTokens,
        tokenDict,
        FLAGS_criterion,
        FLAGS_surround,
        isSeq2seqCrit,
        FLAGS_replabel,
        FLAGS_usewordpiece,
        FLAGS_wordseparator);
    auto predictedWords = tkn2Wrd(predictedLetters, FLAGS_wordseparator);
    auto targetWords = tkn2Wrd(targetLetters, FLAGS_wordseparator);
    mtr.tknEdit.add(predictedLetters, targetLetters);
    mtr.wrdEdit.add(predictedWords, targetWords);
  }
};
auto test = [&evalOutput, &dm, &lexicon, &isSeq2seqCrit, &worldRank](
std::shared_ptr<fl::Module> ntwrk,
std::shared_ptr<SequenceCriterion> crit,
std::shared_ptr<fl::Dataset> validds,
DatasetMeters& mtrs,
double& dmErr) {
ntwrk->eval();
crit->eval();
mtrs.tknEdit.reset();
mtrs.wrdEdit.reset();
mtrs.loss.reset();
auto curValidset = loadPrefetchDataset(
validds, FLAGS_nthread, false /* shuffle */, 0 /* seed */);
if (dm) {
fl::TimeMeter timer;
timer.resume();
LOG(INFO) << "[Beam-search decoder] * DM: compute emissions "
<< curValidset->size();
auto eds = dm->forward(curValidset);
LOG(INFO) << "[Beam-search decoder] * DM: decode";
std::vector<double> lmweights;
for (double lmweight = FLAGS_lmweight_low;
lmweight <= FLAGS_lmweight_high;
lmweight += FLAGS_lmweight_step) {
lmweights.push_back(lmweight);
LOG(INFO) << "LM " << lmweight;
}
std::vector<std::vector<int64_t>> wordEditDst(lmweights.size());
std::vector<std::thread> threads;
for (int i = 0; i < lmweights.size(); i++) {
threads.push_back(std::thread(
[&lmweights, &wordEditDst, &dm, &eds, &lexicon, i, worldRank]() {
af::setDevice(worldRank % 8);
double lmweight = lmweights[i];
DecodeMasterLexiconOptions opt = {
.beamSize = FLAGS_beamsize,
.beamSizeToken = FLAGS_beamsizetoken,
.beamThreshold = FLAGS_beamthreshold,
.lmWeight = lmweight,
.silScore = FLAGS_silscore,
.wordScore = FLAGS_wordscore,
.unkScore = FLAGS_unkscore,
.logAdd = FLAGS_logadd,
.silToken = FLAGS_wordseparator,
.blankToken = kBlankToken,
.unkToken = fl::lib::text::kUnkToken,
.smearMode =
(FLAGS_smearing == "max"
? fl::lib::text::SmearingMode::MAX
: fl::lib::text::SmearingMode::NONE)};
auto pds = dm->decode(eds, lexicon, opt);
// return token distance and word distance stats
wordEditDst[i] = dm->computeMetrics(pds).second;
}));
}
for (auto& thread : threads) {
thread.join();
}
dmErr = DBL_MAX;
for (int i = 0; i < lmweights.size(); i++) {
af::array currentEditDist =
af::constant((long long)(wordEditDst[i][0]), af::dim4(1, 1, 1, 1));
af::array currentTokens =
af::constant((long long)(wordEditDst[i][1]), af::dim4(1, 1, 1, 1));
if (FLAGS_enable_distributed) {
fl::allReduce(currentEditDist);
fl::allReduce(currentTokens);
}
double wer = (double)currentEditDist.scalar<long long>() /
currentTokens.scalar<long long>() * 100.0;
FL_LOG_MASTER(INFO)
<< "[Beam-search decoder] * DM: lmweight=" << lmweights[i]
<< " WER: " << wer;
dmErr = std::min(dmErr, wer);
}
FL_LOG_MASTER(INFO) << "[Beam-search decoder] * DM: done with best WER "
<< dmErr;
timer.stop();
FL_LOG_MASTER(INFO)
<< "[Beam-search decoder] time spent on grid-search for decoding: "
<< timer.value() << "s";
}
for (auto& batch : *curValidset) {
fl::Variable output = ntwrk
->forward(
{fl::input(batch[kInputIdx]),
fl::noGrad(batch[kDurationIdx])})
.front();
std::vector<fl::Variable> critArgs = {
output, fl::Variable(batch[kTargetIdx], false)};
if (isSeq2seqCrit) {
critArgs.push_back(fl::Variable(batch[kDurationIdx], false));
critArgs.push_back(fl::Variable(batch[kTargetSizeIdx], false));
}
auto loss = crit->forward(critArgs).front();
mtrs.loss.add(loss.array());
evalOutput(output.array(), batch[kTargetIdx], batch[kDurationIdx], mtrs);
}
};
int64_t curEpoch = startEpoch;
auto train = [&meters,
&validWerWithDecoder,
&test,
&logStatus,
&saveModels,
&evalOutput,
&validds,
&curEpoch,
&startUpdate,
&isSeq2seqCrit,
&targetTransform,
&tokenToWord,
&targetpadVal,
&plCache,
&plCacheDump,
&plCacheSoft,
&plCacheDumpSoft,
&plBatchCacheFixedSize,
&lexicon,
&tokenDict,
&wordDict,
&worldRank,
reducer](
std::shared_ptr<fl::Module> ntwrk,
std::shared_ptr<fl::Module> ntwrkEMA,
std::shared_ptr<SequenceCriterion> crit,
std::shared_ptr<fl::Dataset> trainset,
std::shared_ptr<fl::Dataset> unsupTrainset,
std::shared_ptr<fl::FirstOrderOptimizer> netopt,
std::shared_ptr<fl::FirstOrderOptimizer> critopt,
double initlr,
double initcritlr,
bool clampCrit,
int64_t nbatches) {
fl::EditDistanceMeter unsupQuality;
meters.train.loss.reset();
meters.trainUnsup.loss.reset();
meters.train.tknEdit.reset();
meters.train.wrdEdit.reset();
meters.trainUnsup.tknEdit.reset();
meters.trainUnsup.wrdEdit.reset();
std::shared_ptr<fl::Module> saug;
std::shared_ptr<fl::Module> saugUnsup;
if (FLAGS_saug_start_update >= 0) {
if (FLAGS_features_type == kFeaturesRaw) {
saugUnsup = std::make_shared<fl::RawWavSpecAugment>(
FLAGS_filterbanks,
FLAGS_saug_fmaskf,
FLAGS_saug_fmaskn,
FLAGS_saug_tmaskt,
FLAGS_saug_tmaskp,
FLAGS_saug_tmaskn,
FLAGS_filterbanks,
FLAGS_lowfreqfilterbank,
FLAGS_highfreqfilterbank,
FLAGS_samplerate);
} else {
saugUnsup = std::make_shared<fl::SpecAugment>(
FLAGS_filterbanks,
FLAGS_saug_fmaskf,
FLAGS_saug_fmaskn,
FLAGS_saug_tmaskt,
FLAGS_saug_tmaskp,
FLAGS_saug_tmaskn);
}
}
if (FLAGS_slimIPL_saug) {
if (FLAGS_features_type == kFeaturesRaw) {
saug = std::make_shared<fl::RawWavSpecAugment>(
FLAGS_filterbanks,
FLAGS_saug_fmaskf,
FLAGS_saug_fmaskn,
FLAGS_saug_tmaskt,
FLAGS_saug_tmaskp,
FLAGS_saug_tmaskn,
FLAGS_filterbanks,
FLAGS_lowfreqfilterbank,
FLAGS_highfreqfilterbank,
FLAGS_samplerate);
} else {
saug = std::make_shared<fl::SpecAugment>(
FLAGS_filterbanks,
FLAGS_saug_fmaskf,
FLAGS_saug_fmaskn + 1,
FLAGS_saug_tmaskt,
FLAGS_saug_tmaskp,
FLAGS_saug_tmaskn * 1.5);
}
} else {
saug = saugUnsup;
}
fl::allReduceParameters(ntwrk);
fl::allReduceParameters(crit);
auto resetTimeStatMeters = [&meters]() {
meters.runtime.reset();
meters.stats.reset();
meters.sampletimer.reset();
meters.fwdtimer.reset();
meters.critfwdtimer.reset();
meters.bwdtimer.reset();
meters.optimtimer.reset();
meters.timer.reset();
};
auto runValAndSaveModel = [&](int64_t totalEpochs,
int64_t totalUpdates,
double lr,
double lrcrit) {
meters.runtime.stop();
meters.timer.stop();
meters.sampletimer.stop();
meters.fwdtimer.stop();
meters.critfwdtimer.stop();
meters.bwdtimer.stop();
meters.optimtimer.stop();
// valid
for (auto& vds : validds) {
double decodedWer;
test(ntwrk, crit, vds.second, meters.valid[vds.first], decodedWer);
if (validWerWithDecoder.find(vds.first) != validWerWithDecoder.end()) {
validWerWithDecoder[vds.first] = decodedWer;
}
}
// print status
try {
logStatus(
meters, validWerWithDecoder, totalEpochs, totalUpdates, lr, lrcrit);
} catch (const std::exception& ex) {
LOG(ERROR) << "Error while writing logs: " << ex.what();
}
// save last and best models
try {
saveModels(totalEpochs, totalUpdates);
} catch (const std::exception& ex) {
LOG(FATAL) << "Error while saving models: " << ex.what();
}
// reset meters for next readings
meters.train.loss.reset();
meters.train.tknEdit.reset();
meters.train.wrdEdit.reset();
meters.trainUnsup.loss.reset();
meters.trainUnsup.tknEdit.reset();
meters.trainUnsup.wrdEdit.reset();
};
int64_t curBatch = startUpdate;
double scaleFactor =
FLAGS_fl_amp_use_mixed_precision ? FLAGS_fl_amp_scale_factor : 1.;
unsigned int kScaleFactorUpdateInterval =
FLAGS_fl_amp_scale_factor_update_interval;
unsigned int kMaxScaleFactor = FLAGS_fl_amp_max_scale_factor;
unsigned short scaleCounter = 1;
bool useUnsup = !(unsupTrainset == nullptr);
FL_LOG_MASTER(INFO) << "Unsup is in use " << useUnsup;
int fixedCacheIndexToLabel = -1;
std::vector<int> unsupBatchesIndices;
if (useUnsup) {
unsupBatchesIndices = std::vector<int>(unsupTrainset->size(), 0);
std::iota(unsupBatchesIndices.begin(), unsupBatchesIndices.end(), 0);
}
int cacheHits =
plBatchCacheFixedSize.size() < FLAGS_slimIPL_fixed_cache_updates
? plBatchCacheFixedSize.size()
: FLAGS_slimIPL_fixed_cache_updates;
std::shared_ptr<fl::Dataset> curUnsupTrainset, curUnsupTrainsetNext;
if (FLAGS_slimIPL_type == "fixed-pre-cache" &&
plBatchCacheFixedSize.size() >= FLAGS_slimIPL_fixed_cache_updates) {
std::random_shuffle(
plBatchCacheFixedSize.begin(), plBatchCacheFixedSize.end());
auto permfn = [plBatchCacheFixedSize](int64_t x) {
return plBatchCacheFixedSize.at(x);
};
curUnsupTrainset = std::make_shared<fl::ResampleDataset>(
unsupTrainset, permfn, plBatchCacheFixedSize.size());
curUnsupTrainset = loadPrefetchDataset(
curUnsupTrainset,
FLAGS_nthread,
false /* shuffle */,
curBatch /* seed */);
}
while (curBatch < nbatches) {
++curEpoch; // counts partial epochs too!
int64_t epochsAfterDecay = curEpoch - FLAGS_lr_decay;
double lrDecayScale = std::pow(
0.5,
(epochsAfterDecay < 0 ? 0
: 1 + epochsAfterDecay / FLAGS_lr_decay_step));
ntwrk->train();
crit->train();
if (FLAGS_reportiters == 0) {
resetTimeStatMeters();
}
std::hash<std::string> hasher;
FL_LOG_MASTER(INFO) << "Shuffling trainset";
auto curTrainset = loadPrefetchDataset(
trainset, FLAGS_nthread, true /* shuffle */, curEpoch /* seed */);
if (useUnsup) {
if (FLAGS_slimIPL_type != "fixed-pre-cache") {
FL_LOG_MASTER(INFO) << "Shuffling unsup trainset";
curUnsupTrainset = loadPrefetchDataset(
unsupTrainset,
FLAGS_nthread,
true /* shuffle */,
curBatch /* seed */);
} else {
FL_LOG_MASTER(INFO) << "Preparing next unsup trainset";
std::random_shuffle(
unsupBatchesIndices.begin(), unsupBatchesIndices.end());
auto permfn = [unsupBatchesIndices](int64_t x) {
return unsupBatchesIndices[x % unsupBatchesIndices.size()];
};
curUnsupTrainsetNext =
std::make_shared<fl::ResampleDataset>(unsupTrainset, permfn);
curUnsupTrainsetNext = loadPrefetchDataset(
curUnsupTrainsetNext,
FLAGS_nthread,
false /* shuffle */,
curBatch /* seed */);
}
}
af::sync();
meters.sampletimer.resume();
meters.runtime.resume();
meters.timer.resume();
FL_LOG_MASTER(INFO) << "Epoch " << curEpoch << " started!";
int unsupBatchIdx = 0, supBatchIdx = 0, setsOrderIdx = 0;
std::vector<bool> setsOrder;
int unsupSteps = useUnsup ? FLAGS_slimIPL_unsup_updates : 0;
for (int index = 0; index < FLAGS_slimIPL_sup_updates + unsupSteps;
index++) {
if (index < FLAGS_slimIPL_sup_updates) {
setsOrder.push_back(true);
} else {
setsOrder.push_back(false);
}
}
std::random_shuffle(setsOrder.begin(), setsOrder.end());
while (supBatchIdx < curTrainset->size()) {
++curBatch;
std::vector<af::array> batch;
bool isSupBatch = setsOrder[setsOrderIdx];
std::vector<std::string> plTextArrayPreCacheToSave;
bool fixedCacheRelabel = true;
if (isSupBatch) {
batch = curTrainset->get(supBatchIdx % curTrainset->size());
LOG(INFO) << "Sup batch " << curBatch << " | " << supBatchIdx << " | "
<< batch[kInputIdx].dims();
++supBatchIdx;
} else {
if (FLAGS_slimIPL_type == "fixed-pre-cache") {
float rNumber =
static_cast<float>(std::rand()) / static_cast<float>(RAND_MAX);
if (plBatchCacheFixedSize.size() <
FLAGS_slimIPL_fixed_cache_updates ||
rNumber < FLAGS_slimIPL_fixed_cache_update_prob) {
fixedCacheIndexToLabel++;
fixedCacheRelabel = true;
} else {
fixedCacheRelabel = false;
}
if (fixedCacheIndexToLabel >= unsupBatchesIndices.size()) {
fixedCacheIndexToLabel = 0;
std::random_shuffle(
unsupBatchesIndices.begin(), unsupBatchesIndices.end());
auto permfn = [unsupBatchesIndices](int64_t x) {
return unsupBatchesIndices[x % unsupBatchesIndices.size()];
};
curUnsupTrainsetNext =
std::make_shared<fl::ResampleDataset>(unsupTrainset, permfn);
curUnsupTrainsetNext = loadPrefetchDataset(
curUnsupTrainsetNext,
FLAGS_nthread,
false /* shuffle */,
curBatch /* seed */);
}
if (cacheHits == FLAGS_slimIPL_fixed_cache_updates) {
// we read the whole cache, time to take another one
cacheHits = 0;
std::random_shuffle(
plBatchCacheFixedSize.begin(), plBatchCacheFixedSize.end());
auto permfn = [plBatchCacheFixedSize](int64_t x) {
return plBatchCacheFixedSize.at(x);
};
curUnsupTrainset = std::make_shared<fl::ResampleDataset>(
unsupTrainset, permfn, plBatchCacheFixedSize.size());
curUnsupTrainset = loadPrefetchDataset(
curUnsupTrainset,
FLAGS_nthread,
false /* shuffle */,
curBatch /* seed */);
}
if (plBatchCacheFixedSize.size() >=
FLAGS_slimIPL_fixed_cache_updates) {
batch =
curUnsupTrainset->get(cacheHits % curUnsupTrainset->size());
if (fixedCacheRelabel) {
if (unsupBatchesIndices[fixedCacheIndexToLabel] < 0) {
LOG(FATAL)
<< "Error in the index which wll be saved into cache";
}
plBatchCacheFixedSize[cacheHits] =
unsupBatchesIndices[fixedCacheIndexToLabel];
}
LOG(INFO) << "Unsup batch " << curBatch << " | " << cacheHits
<< " | " << batch[kInputIdx].dims() << " update cache "
<< fixedCacheRelabel;
} else {
if (unsupBatchesIndices[fixedCacheIndexToLabel] < 0) {
LOG(FATAL)
<< "Error in the index while preparing first state of cache";
}
plBatchCacheFixedSize.push_back(
unsupBatchesIndices[fixedCacheIndexToLabel]);
LOG(INFO)
<< "Skip usage of unsup batch as fixed cache is not ready "
<< curBatch;
}
cacheHits++;
} else {
if (unsupBatchIdx < 0) {
LOG(FATAL) << "index data is negative "
<< "unsupBatchIdx" << unsupBatchIdx;
}
batch =
curUnsupTrainset->get(unsupBatchIdx % curUnsupTrainset->size());
LOG(INFO) << "Unsup batch " << curBatch << " | " << unsupBatchIdx
<< " | " << batch[kInputIdx].dims();
++unsupBatchIdx;
if (unsupBatchIdx >= curUnsupTrainset->size()) {
unsupBatchIdx = 0;
FL_LOG_MASTER(INFO) << "Shuffling unsup trainset";
curUnsupTrainset = loadPrefetchDataset(
unsupTrainset,
FLAGS_nthread,
true /* shuffle */,
curBatch /* seed */);
}
}
}
setsOrderIdx++;
if (setsOrderIdx >= setsOrder.size()) {
setsOrderIdx = 0;
std::random_shuffle(setsOrder.begin(), setsOrder.end());
}
double lrScheduleScale;
if (FLAGS_lrcosine) {
const double pi = std::acos(-1);
lrScheduleScale =
std::cos(((double)curBatch) / ((double)nbatches) * pi / 2.0);
} else {
lrScheduleScale =
std::pow(FLAGS_gamma, (double)curBatch / (double)FLAGS_stepsize);
}
netopt->setLr(
initlr * lrDecayScale * lrScheduleScale *
std::min(curBatch / double(FLAGS_warmup), 1.0));
critopt->setLr(
initcritlr * lrDecayScale * lrScheduleScale *
std::min(curBatch / double(FLAGS_warmup), 1.0));
af::sync();
meters.timer.incUnit();
meters.sampletimer.stopAndIncUnit();
if (isSupBatch) {
meters.stats.add(batch[kDurationIdx], batch[kTargetSizeIdx]);
}
if (!batch.empty() &&
(af::anyTrue<bool>(af::isNaN(batch[kInputIdx])) ||
af::anyTrue<bool>(af::isNaN(batch[kTargetIdx])))) {
LOG(FATAL) << "Sample has NaN values - "
<< join(",", readSampleIds(batch[kSampleIdx]));
}
// Pseudo-labels (PL) `inpBatch` with the EMA network and scores the labels.
//
// The EMA network and the criterion are put in eval mode for the forward pass
// and the Viterbi decode, then switched back to train mode. Every decoded
// column is converted with `tokenToWord` and joined with single spaces; the
// word sequence is also compared with the batch's own target column through
// `unsupQuality`, which is reset on each call, synced across workers, and
// printed by rank 0. Sample ids and PL texts are logged on the master only
// when `curBatch` is a multiple of 100.
//
// Returns {one PL text per decoded column, raw EMA network output}.
auto predictPLCommon = [&](std::vector<af::array> inpBatch)
    -> std::pair<std::vector<std::string>, af::array> {
  // Inference pass: no dropout etc. while producing labels.
  ntwrkEMA->eval();
  crit->eval();
  const std::vector<fl::Variable> emaInputs = {
      fl::input(inpBatch[kInputIdx]), fl::noGrad(inpBatch[kDurationIdx])};
  auto outputUnsupOriginal = ntwrkEMA->forward(emaInputs).front().array();

  // One token sequence per column of the Viterbi path.
  auto viterbiPath =
      crit->viterbiPath(outputUnsupOriginal, inpBatch[kDurationIdx]);
  std::vector<std::vector<int>> tokenPredictions;
  for (int col = 0; col < viterbiPath.dims(1); ++col) {
    tokenPredictions.push_back(afToVector<int>(viterbiPath.col(col)));
  }

  // Restore training mode before returning to the training loop.
  ntwrkEMA->train();
  crit->train();

  unsupQuality.reset();
  const bool logThisBatch = (curBatch % 100 == 0);
  if (logThisBatch) {
    FL_LOG_MASTER(INFO) << "PL for samples "
                        << join(",", readSampleIds(inpBatch[kSampleIdx]));
  }

  std::vector<std::string> plTextArray;
  int sampleIdx = 0;
  for (const auto& predictedTokens : tokenPredictions) {
    auto predictedWords = tokenToWord(predictedTokens, true);
    auto referenceWords = tokenToWord(
        afToVector<int>(inpBatch[kTargetIdx].col(sampleIdx)), false);
    auto plText = fl::lib::join(" ", predictedWords);
    if (logThisBatch) {
      FL_LOG_MASTER(INFO) << "PL for index " << sampleIdx << ": " << plText;
    }
    unsupQuality.add(predictedWords, referenceWords);
    plTextArray.push_back(plText);
    ++sampleIdx;
  }

  // Aggregate the quality meter over all workers; only rank 0 reports it.
  fl::pkg::runtime::syncMeter(unsupQuality);
  if (fl::getWorldRank() == 0) {
    std::cout << "PL Quality for Batch " << curBatch << " : "
              << unsupQuality.errorRate()[0] << std::endl;
  }
  return {plTextArray, outputUnsupOriginal};
};
// Hard pseudo-labels only: the text half of predictPLCommon's result.
auto predictPL =
    [&](std::vector<af::array> inpBatch) -> std::vector<std::string> {
  auto textsAndOutput = predictPLCommon(inpBatch);
  return std::move(textsAndOutput.first);
};
// Soft pseudo-labels only: the raw network output half of predictPLCommon's
// result.
auto predictSoftPL = [&](std::vector<af::array> inpBatch) -> af::array {
  auto textsAndOutput = predictPLCommon(inpBatch);
  return std::move(textsAndOutput.second);
};
// Ensure no samples are skipped while adjusting the loss scale factor.
// When gradient values are Inf/NaN, the model update is skipped and the
// scale factor is adjusted accordingly for determinism.
// The AMP algorithm implemented here mirrors:
// - https://arxiv.org/abs/1710.03740
// - https://bit.ly/35F5GqX
// - https://bit.ly/3mn2qr0
bool retrySample = false;
bool doUpdate = true;
std::vector<std::string> samplesIndices;
if (!batch.empty()) {
samplesIndices = readSampleIds(batch[kSampleIdx]);
}
do {
retrySample = false;
std::vector<fl::Variable> critArgs;
fl::Variable input, output;
af::array newUnsupDuration;
// forward
meters.fwdtimer.resume();
if (!batch.empty()) {
if (batch[kInputIdx].dims(3) > 60) {
int length = 59;
LOG(INFO) << "Shrink batch, too huge " << batch[kInputIdx].dims()
<< " | " << batch[kDurationIdx].dims() << " | "
<< batch[kTargetIdx].dims() << " | "
<< batch[kSampleIdx].dims();
batch[kInputIdx] = af::reorder(
af::reorder(batch[kInputIdx], 0, 3, 2, 1).cols(0, length),
0,
3,
2,
1);
batch[kDurationIdx] = batch[kDurationIdx].cols(0, length);
batch[kTargetIdx] = batch[kTargetIdx].cols(0, length);
batch[kSampleIdx] = batch[kSampleIdx].cols(0, length);
}
input = fl::input(batch[kInputIdx]);
if (FLAGS_saug_start_update >= 0 &&
curBatch >= FLAGS_saug_start_update) {
if (isSupBatch) {
input = saug->forward({input}).front();
} else {
input = saugUnsup->forward({input}).front();
}
}
std::vector<fl::Variable> fwdParams = {
input, fl::noGrad(batch[kDurationIdx])};
if (FLAGS_slimIPL_dyn_dropout >= 0 && useUnsup) {
fwdParams.push_back(fl::noGrad(
af::constant(FLAGS_slimIPL_dyn_dropout, af::dim4(1))));
}
output = ntwrk->forward(fwdParams).front();
}
if (isSupBatch) {
critArgs = {output, fl::Variable(batch[kTargetIdx], false)};
if (isSeq2seqCrit) {
critArgs.push_back(fl::Variable(batch[kDurationIdx], false));
critArgs.push_back(fl::Variable(batch[kTargetSizeIdx], false));
}
} else {
// unsup Batch
std::vector<std::string> plTextArray;
std::vector<af::array> plSoftArray;
std::vector<int> useUnsupSamplesIndices;
if (FLAGS_slimIPL_use_soft &&
FLAGS_slimIPL_type == "fixed-pre-cache") {
if (!batch.empty()) {
for (int index = 0; index < samplesIndices.size(); index++) {
if (plCacheSoft.find(samplesIndices[index]) ==
plCacheSoft.end() &&
plCacheDumpSoft.find(samplesIndices[index]) !=
plCacheDumpSoft.end()) {
LOG(INFO)
<< "Reuse extra loaded soft cache for sample "
<< samplesIndices[index] << " for batch " << curBatch;
plCacheSoft[samplesIndices[index]] =
plCacheDumpSoft[samplesIndices[index]];
}
}
for (int index = 0; index < samplesIndices.size(); index++) {
if (plCacheDumpSoft.find(samplesIndices[index]) !=
plCacheDumpSoft.end()) { // place to add filtering too
useUnsupSamplesIndices.push_back(index);
plSoftArray.push_back(
plCacheDumpSoft[samplesIndices[index]]);
}
}
}
if (FLAGS_slimIPL_type == "fixed-pre-cache" &&
fixedCacheRelabel) {
if (unsupBatchIdx < 0) {
LOG(FATAL)
<< "index data is negative "
<< "fixedCacheIndexToLabel" << fixedCacheIndexToLabel;
}
auto nextBatch = curUnsupTrainsetNext->get(
fixedCacheIndexToLabel % curUnsupTrainsetNext->size());
auto samplesIndicesNext = readSampleIds(nextBatch[kSampleIdx]);
auto plSoftArrayPreCache = predictSoftPL(nextBatch);
for (int index = 0; index < plSoftArrayPreCache.dims(2);
index++) {
plCacheSoft[samplesIndicesNext[index]] =
plSoftArrayPreCache(af::span, af::span, index);
}
}
if (useUnsupSamplesIndices.size() > 0) {
af::array maskedSamples = af::array(
af::dim4(useUnsupSamplesIndices.size()),
useUnsupSamplesIndices.data());
std::vector<fl::Variable> newTargets;
for (int index = 0; index < useUnsupSamplesIndices.size();
index++) {
newTargets.push_back(fl::Variable(
plCacheSoft
[samplesIndices[useUnsupSamplesIndices[index]]],
false));
}
meters.stats.add(
batch[kDurationIdx](maskedSamples),
batch[kTargetSizeIdx](maskedSamples));
critArgs = {
output(af::span, af::span, maskedSamples, af::span, true),
fl::concatenate(newTargets, 2)};
if (isSeq2seqCrit) {
critArgs.push_back(
fl::Variable(batch[kDurationIdx](maskedSamples), false));
critArgs.push_back(fl::Variable(
batch[kTargetSizeIdx](maskedSamples), false));
}
} else {
LOG(INFO)
<< "Skip unsupervised part of data as PL are not available yet";
}
} else {
if (FLAGS_slimIPL_type == "naive") {
// generate by current NN the labels;
plTextArray = predictPL(batch);
useUnsupSamplesIndices = std::vector<int>(plTextArray.size());
std::iota(
useUnsupSamplesIndices.begin(),
useUnsupSamplesIndices.end(),
0);
} else if (
FLAGS_slimIPL_type == "cache" ||
FLAGS_slimIPL_type == "pre-cache" ||
(FLAGS_slimIPL_type == "fixed-pre-cache" && !batch.empty())) {
for (int index = 0; index < samplesIndices.size(); index++) {
if (plCache.find(samplesIndices[index]) == plCache.end() &&
plCacheDump.find(samplesIndices[index]) !=
plCacheDump.end()) {
LOG(INFO)
<< "Reuse extra loaded cache for sample "
<< samplesIndices[index] << " for batch " << curBatch;
plCache[samplesIndices[index]] =
plCacheDump[samplesIndices[index]];
}
}
for (int index = 0; index < samplesIndices.size(); index++) {
if (plCache.find(samplesIndices[index]) !=
plCache.end()) { // place to add filtering too
useUnsupSamplesIndices.push_back(index);
plTextArray.push_back(plCache[samplesIndices[index]]);
}
}
if (FLAGS_slimIPL_type == "pre-cache" ||
useUnsupSamplesIndices.size() == 0) {
// update Cache before doing model update
plTextArrayPreCacheToSave = predictPL(batch);
}
}
if (FLAGS_slimIPL_type == "fixed-pre-cache" &&
fixedCacheRelabel) {
if (unsupBatchIdx < 0) {
LOG(FATAL)
<< "index data is negative "
<< "fixedCacheIndexToLabel" << fixedCacheIndexToLabel;
}
auto nextBatch = curUnsupTrainsetNext->get(
fixedCacheIndexToLabel % curUnsupTrainsetNext->size());
auto samplesIndicesNext = readSampleIds(nextBatch[kSampleIdx]);
auto plTextArrayPreCache = predictPL(nextBatch);
for (int index = 0; index < plTextArrayPreCache.size();
index++) {
plCache[samplesIndicesNext[index]] =
plTextArrayPreCache[index];
}
}
if (useUnsupSamplesIndices.size() > 0) {
af::array maskedSamples = af::array(
af::dim4(useUnsupSamplesIndices.size()),
useUnsupSamplesIndices.data());
std::vector<af::array> newTargets, newTargetsSize;
for (auto& plText : plTextArray) {
std::vector<char> curTarget(plText.begin(), plText.end());
auto target = targetTransform(
static_cast<void*>(curTarget.data()),
{static_cast<dim_t>(curTarget.size())},
af::dtype::b8);
newTargets.push_back(target);
newTargetsSize.push_back(
af::constant(float(target.elements()), 1));
}
fl::Dataset::BatchFunction fnc =
[targetpadVal](const std::vector<af::array>& arr) {
return fl::join(arr, targetpadVal, 1);
};
auto newTargetsSizeBatch =
fl::makeBatch(newTargetsSize, nullptr);
auto newTargetsBatch = fl::makeBatch(newTargets, fnc);
meters.stats.add(
batch[kDurationIdx](maskedSamples), newTargetsSizeBatch);
// TODO optimize masking early
critArgs = {
output(af::span, af::span, maskedSamples, af::span, true),
fl::Variable(newTargetsBatch, false)};
newUnsupDuration = batch[kDurationIdx](maskedSamples);
if (isSeq2seqCrit) {
critArgs.push_back(fl::Variable(newUnsupDuration, false));
critArgs.push_back(fl::Variable(newTargetsSizeBatch, false));
}
} else {
LOG(INFO)
<< "Skip unsupervised part of data as PL are not available yet";
}
}
}
float r = critArgs.size() > 0;
af::array doUpdateArr = af::array(1, &r);
af::sync();
if (FLAGS_enable_distributed) {
fl::allReduce(doUpdateArr);
}
if (af::sum<int>(doUpdateArr) < fl::getWorldSize()) {
doUpdate = false;
break;
}
meters.critfwdtimer.resume();
fl::Variable loss;
if (!isSupBatch && FLAGS_slimIPL_use_soft) {
af::print("target", critArgs[1].array());
af::print("pred", critArgs[0].array());
loss = FLAGS_slimIPL_soft_scale *
fl::negate(fl::mean(
fl::sum(
fl::softmax(critArgs[1].as(f32), 0) *
fl::logSoftmax(critArgs[0].as(f32), 0),
{0}),
{1, 2, 3}));
} else {
loss = crit->forward(critArgs).front();
}
af::sync();
meters.fwdtimer.stopAndIncUnit();
meters.critfwdtimer.stopAndIncUnit();
if (FLAGS_fl_amp_use_mixed_precision) {
++scaleCounter;
loss = loss * scaleFactor;
}
if (af::anyTrue<bool>(af::isNaN(loss.array())) ||
af::anyTrue<bool>(af::isInf(loss.array()))) {
if (af::anyTrue<bool>(af::isInf(critArgs[0].array()))) {
LOG(INFO) << "input to crit has Inf values. Samples - "
<< join(",", readSampleIds(batch[kSampleIdx]));
}
if (af::anyTrue<bool>(af::isNaN(critArgs[0].array()))) {
LOG(INFO) << "input to crit has NaN values. Samples - "
<< join(",", readSampleIds(batch[kSampleIdx]));
}
LOG(FATAL) << "Loss has NaN values. Samples - "
<< join(",", readSampleIds(batch[kSampleIdx]));
}
if (hasher(join(",", readSampleIds(batch[kSampleIdx]))) % 100 <=
FLAGS_pcttraineval) {
if (isSupBatch) {
evalOutput(
critArgs[0].array(),
critArgs[1].array(),
batch[kDurationIdx],
meters.train);
} else {
evalOutput(
critArgs[0].array(),
critArgs[1].array(),
newUnsupDuration,
meters.trainUnsup);
}
}
// backward
meters.bwdtimer.resume();
netopt->zeroGrad();
critopt->zeroGrad();
loss.backward();
if (reducer) {
for (auto& p : ntwrk->params()) {
if (!p.isGradAvailable()) {
p.addGrad(fl::constant(0.0, p.dims(), p.type(), false));
}
reducer->add(p.grad());
}
for (auto& p : crit->params()) {
if (!p.isGradAvailable()) {
p.addGrad(fl::constant(0.0, p.dims(), p.type(), false));
}
reducer->add(p.grad());
}
reducer->finalize();
}
af::sync();
meters.bwdtimer.stopAndIncUnit();
// optimizer
meters.optimtimer.resume();
// scale down gradients by batchsize
af::array totalBatchSizeArr = af::constant(loss.dims(0), 1, f32);
if (reducer) {
fl::allReduce(totalBatchSizeArr);
}
float totalBatchSize = totalBatchSizeArr.scalar<float>();
for (const auto& p : ntwrk->params()) {
if (!p.isGradAvailable()) {
continue;
}
p.grad() = p.grad() / (totalBatchSize * scaleFactor);
if (FLAGS_fl_amp_use_mixed_precision) {
if (af::anyTrue<bool>(af::isNaN(p.grad().array())) ||
af::anyTrue<bool>(af::isInf(p.grad().array()))) {
if (scaleFactor >= fl::kAmpMinimumScaleFactorValue) {
scaleFactor = scaleFactor / 2.0f;
FL_VLOG(2) << "AMP: Scale factor decreased. New value:\t"
<< scaleFactor;
retrySample = true;
}
scaleCounter = 1;
break;
}
}
}
if (retrySample) {
LOG(INFO) << "Retry amp sample " << scaleFactor;
meters.optimtimer.stop();
continue;
}
if (isSupBatch) {
meters.train.loss.add((loss / scaleFactor).array());
} else {
meters.trainUnsup.loss.add((loss / scaleFactor).array());
}
for (const auto& p : crit->params()) {
if (!p.isGradAvailable()) {
continue;
}
p.grad() = p.grad() / (totalBatchSize * scaleFactor);
}
} while (retrySample);
for (int index = 0; index < plTextArrayPreCacheToSave.size(); index++) {
plCache[samplesIndices[index]] = plTextArrayPreCacheToSave[index];
}
if (doUpdate) {
// clamp gradients
if (FLAGS_maxgradnorm > 0) {
auto params = ntwrk->params();
if (clampCrit) {
auto critparams = crit->params();
params.insert(params.end(), critparams.begin(), critparams.end());
}
fl::clipGradNorm(params, FLAGS_maxgradnorm);
}
// update weights
critopt->step();
netopt->step();
af::sync();
meters.optimtimer.stopAndIncUnit();
// update scale factor
if (FLAGS_fl_amp_use_mixed_precision &&
scaleFactor < kMaxScaleFactor) {
if (scaleCounter % kScaleFactorUpdateInterval == 0) {
scaleFactor *= 2;
FL_VLOG(2) << "AMP: Scale factor doubled. New value:\t"
<< scaleFactor;
} else {
scaleFactor += 2;
FL_VLOG(3) << "AMP: Scale factor incremented. New value\t"
<< scaleFactor;
}
}
} else {
LOG(INFO) << "Skip update step as unsup data has no label "
<< curBatch;
}
// update EMA model
if (FLAGS_slimIPL_ema) {
for (int i = 0; i < ntwrkEMA->params().size(); ++i) {
af::array newParam =
ntwrkEMA->param(i).array() * FLAGS_slimIPL_ema_decay +
ntwrk->param(i).array() * (1 - FLAGS_slimIPL_ema_decay);
newParam.eval();
ntwrkEMA->setParams(fl::Variable(newParam, false), i);
}
}
if (!isSupBatch) {
if (FLAGS_slimIPL_type == "cache") {
auto plTextArray = predictPL(batch);
auto samplesIndices = readSampleIds(batch[kSampleIdx]);
for (int index = 0; index < plTextArray.size(); index++) {
plCache[samplesIndices[index]] = plTextArray[index];
}
}
}
meters.sampletimer.resume();
if (FLAGS_reportiters > 0 && curBatch % FLAGS_reportiters == 0) {
runValAndSaveModel(
curEpoch, curBatch, netopt->getLr(), critopt->getLr());
resetTimeStatMeters();
ntwrk->train();
crit->train();
meters.sampletimer.resume();
meters.runtime.resume();
meters.timer.resume();
}
if (curBatch > nbatches) {
break;
}
}
af::sync();
if (FLAGS_reportiters == 0) {
runValAndSaveModel(
curEpoch, curBatch, netopt->getLr(), critopt->getLr());
}
}
};
/* ===================== Train ===================== */
// The stages below run in order through the `train` lambda. Each of the first
// three is skipped when `startUpdate` has already reached its threshold, and
// otherwise moves `startUpdate` up to that threshold once it finishes.
//
// Stage 1: LinSeg. Uses the `linseg` criterion with its own optimizers and
// initial learning rates, and passes nullptr in the slot where the final call
// passes `unsupTrainds`.
// NOTE(review): the last argument here is a remaining count
// (FLAGS_linseg - startUpdate), whereas the later stages pass an absolute
// threshold (FLAGS_pretrainWindow, FLAGS_slimIPL_start, FLAGS_iter) -- confirm
// which one `train` expects.
if (FLAGS_linseg - startUpdate > 0) {
train(
network,
networkEMA,
linseg,
trainds,
nullptr,
linNetoptim,
linCritoptim,
initLinNetlr,
initLinCritlr,
false /* clampCrit */,
FLAGS_linseg - startUpdate);
startUpdate = FLAGS_linseg;
FL_LOG_MASTER(INFO) << "Finished LinSeg";
}
// Both casts yield null unless `criterion` is of the respective type; they
// gate the window pretraining stage and the clearWindow() call below.
auto s2s = std::dynamic_pointer_cast<Seq2SeqCriterion>(criterion);
auto trde = std::dynamic_pointer_cast<TransformerCriterion>(criterion);
// Stage 2: window pretraining. Aborts via LOG(FATAL) when the criterion is
// neither a Seq2SeqCriterion nor a TransformerCriterion.
if (FLAGS_pretrainWindow - startUpdate > 0) {
if (!s2s && !trde) {
LOG(FATAL) << "Window pretraining only allowed for seq2seq.";
}
train(
network,
networkEMA,
criterion,
trainds,
nullptr,
netoptim,
critoptim,
FLAGS_lr,
FLAGS_lrcrit,
true,
FLAGS_pretrainWindow);
startUpdate = FLAGS_pretrainWindow;
FL_LOG_MASTER(INFO) << "Finished window pretraining.";
}
// The window is cleared whether or not the pretraining stage above ran.
if (s2s) {
s2s->clearWindow();
} else if (trde) {
trde->clearWindow();
}
// Stage 3: supervised-only pretraining (nullptr in the `unsupTrainds` slot)
// up to FLAGS_slimIPL_start.
if (FLAGS_slimIPL_start - startUpdate > 0) {
train(
network,
networkEMA,
criterion,
trainds,
nullptr,
netoptim,
critoptim,
FLAGS_lr,
FLAGS_lrcrit,
true /* clampCrit */,
FLAGS_slimIPL_start);
startUpdate = FLAGS_slimIPL_start;
FL_LOG_MASTER(INFO) << "Finished supervised only pretraining.";
}
// Final stage: always runs, and is the only call that passes `unsupTrainds`;
// its last argument is FLAGS_iter.
train(
network,
networkEMA,
criterion,
trainds,
unsupTrainds,
netoptim,
critoptim,
FLAGS_lr,
FLAGS_lrcrit,
true /* clampCrit */,
FLAGS_iter);
FL_LOG_MASTER(INFO) << "Finished training";
return 0;
}