in src/fasttext.cc [632:678]
void FastText::trainThread(int32_t threadId, const TrainCallback& callback) {
std::ifstream ifs(args_->input);
utils::seek(ifs, threadId * utils::size(ifs) / args_->thread);
Model::State state(args_->dim, output_->size(0), threadId + args_->seed);
const int64_t ntokens = dict_->ntokens();
int64_t localTokenCount = 0;
std::vector<int32_t> line, labels;
uint64_t callbackCounter = 0;
try {
while (keepTraining(ntokens)) {
real progress = real(tokenCount_) / (args_->epoch * ntokens);
if (callback && ((callbackCounter++ % 64) == 0)) {
double wst;
double lr;
int64_t eta;
std::tie<double, double, int64_t>(wst, lr, eta) =
progressInfo(progress);
callback(progress, loss_, wst, lr, eta);
}
real lr = args_->lr * (1.0 - progress);
if (args_->model == model_name::sup) {
localTokenCount += dict_->getLine(ifs, line, labels);
supervised(state, lr, line, labels);
} else if (args_->model == model_name::cbow) {
localTokenCount += dict_->getLine(ifs, line, state.rng);
cbow(state, lr, line);
} else if (args_->model == model_name::sg) {
localTokenCount += dict_->getLine(ifs, line, state.rng);
skipgram(state, lr, line);
}
if (localTokenCount > args_->lrUpdateRate) {
tokenCount_ += localTokenCount;
localTokenCount = 0;
if (threadId == 0 && args_->verbose > 1) {
loss_ = state.getLoss();
}
}
}
} catch (DenseMatrix::EncounteredNaNError&) {
trainException_ = std::current_exception();
}
if (threadId == 0)
loss_ = state.getLoss();
ifs.close();
}