in src/model.cpp [160:341]
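// Runs one training pass: the shuffled corpus is split into contiguous slices,
// one per worker thread, while a separate thread truncates embedding norms in
// the background. The learning rate decays linearly from `rate` toward
// `finishRate`, and the mean loss per trained batch is returned.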
Real EmbedModel::train(shared_ptr<InternDataHandler> data,
int numThreads,
std::chrono::time_point<std::chrono::high_resolution_clock> t_start,
int epochs_done,
Real rate,
Real finishRate,
bool verbose) {
assert(rate >= finishRate);
assert(rate >= 0.0);
// Use a layer of indirection when accessing the corpus to allow shuffling.
auto numSamples = data->getSize();
vector<int> indices(numSamples);
{
int i = 0;
for (auto& idx: indices) idx = i++;
}
std::random_shuffle(indices.begin(), indices.end());
// Compute word negatives
if (args_->trainMode == 5 || args_->trainWord) {
data->initWordNegatives();
}
// If we decremented the learning rate after *every* sample, floating-point
// precision would swallow the tiny update, so decay it once per kDecrStep
// samples instead.
const int kDecrStep = 1000;
auto decrPerKSample = (rate - finishRate) / (numSamples / kDecrStep);
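// `rate` is shared by reference across all worker threads; each thread
// subtracts decrPerKSample every kDecrStep samples, so the combined decay
// over one pass is roughly (rate - finishRate).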
const Real negSearchLimit = (std::min)(numSamples,
size_t(args_->negSearchLimit));
numThreads = (std::max)(numThreads, 2);
numThreads -= 1; // Withhold one thread for the norm-truncation thread below.
numThreads = (std::min)(numThreads, int(numSamples));
vector<Real> losses(numThreads);
vector<long> counts(numThreads);
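// Per-thread accumulators: losses[idx] sums the batch losses and counts[idx]
// counts the trained batches for worker idx; they are combined at the end.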
auto trainThread = [&](int idx,
vector<int>::const_iterator start,
vector<int>::const_iterator end) {
assert(start >= indices.begin());
assert(end >= start);
assert(end <= indices.end());
bool amMaster = idx == 0;
auto t_epoch_start = std::chrono::high_resolution_clock::now();
losses[idx] = 0.0;
counts[idx] = 0;
unsigned int batch_sz = args_->batchSize;
vector<ParseResults> examples;
for (auto ip = start; ip < end; ip++) {
auto i = *ip;
float thisLoss = 0.0;
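// Word-level training (trainMode 5 or trainWord): fetch the word examples for
// this record and train on them in batches of batch_sz, flushing the final
// partial batch as well.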
if (args_->trainMode == 5 || args_->trainWord) {
vector<ParseResults> exs;
data->getWordExamples(i, exs);
vector<ParseResults> word_exs;
for (unsigned int j = 0; j < exs.size(); j++) {
word_exs.push_back(exs[j]);
if (word_exs.size() >= batch_sz || j == exs.size() - 1) {
if (args_->loss == "softmax") {
thisLoss = trainNLLBatch(data, word_exs, negSearchLimit, rate, true);
} else {
thisLoss = trainOneBatch(data, word_exs, negSearchLimit, rate, true);
}
word_exs.clear();
assert(thisLoss >= 0.0);
counts[idx]++;
losses[idx] += thisLoss;
}
}
}
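// Record-level training (every mode except 5): accumulate parsed examples
// into a batch and train once the batch is full or this thread's slice ends.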
if (args_->trainMode != 5) {
ParseResults ex;
data->getExampleById(i, ex);
if (ex.LHSTokens.size() == 0 || ex.RHSTokens.size() == 0) {
continue;
}
examples.push_back(ex);
if (examples.size() >= batch_sz || (ip + 1) == end) {
if (args_->loss == "softmax") {
thisLoss = trainNLLBatch(data, examples, negSearchLimit, rate, false);
} else {
thisLoss = trainOneBatch(data, examples, negSearchLimit, rate, false);
}
examples.clear();
assert(thisLoss >= 0.0);
counts[idx]++;
losses[idx] += thisLoss;
}
}
// Update the shared learning rate racily (no locking across threads).
if ((i % kDecrStep) == (kDecrStep - 1)) {
rate -= decrPerKSample;
}
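// Stop this worker early once the global training-time budget is exhausted.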
auto t_end = std::chrono::high_resolution_clock::now();
auto tot_spent = std::chrono::duration<double>(t_end-t_start).count();
if (tot_spent > args_->maxTrainTime) {
break;
}
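// Only the first worker reports progress, every 100 examples and at the end
// of its slice.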
if (amMaster && ((ip - indices.begin()) % 100 == 99 || (ip + 1) == end)) {
auto t_epoch_spent =
std::chrono::duration<double>(t_end-t_epoch_start).count();
double ex_done_this_epoch = ip - indices.begin();
int ex_left = ((end - start) * (args_->epoch - epochs_done))
- ex_done_this_epoch;
double ex_done = epochs_done * (end - start) + ex_done_this_epoch;
double time_per_ex = double(t_epoch_spent) / ex_done_this_epoch;
int eta = int(time_per_ex * double(ex_left));
double epoch_progress = ex_done_this_epoch / (end - start);
double progress = ex_done / (ex_done + ex_left);
if (eta > args_->maxTrainTime - tot_spent) {
eta = args_->maxTrainTime - tot_spent;
progress = tot_spent / (eta + tot_spent);
}
int etah = eta / 3600;
int etam = (eta - etah * 3600) / 60;
int toth = int(tot_spent) / 3600;
int totm = (tot_spent - toth * 3600) / 60;
int tots = (tot_spent - toth * 3600 - totm * 60);
std::cerr << std::fixed;
std::cerr << "\rEpoch: " << std::setprecision(1) << 100 * epoch_progress << "%";
std::cerr << " lr: " << std::setprecision(6) << rate;
std::cerr << " loss: " << std::setprecision(6) << losses[idx] / counts[idx];
if (eta < 60) {
std::cerr << " eta: <1min ";
} else {
std::cerr << " eta: " << std::setprecision(3) << etah << "h" << etam << "m";
}
std::cerr << " tot: " << std::setprecision(3) << toth << "h" << totm << "m" << tots << "s ";
std::cerr << " (" << std::setprecision(1) << 100 * progress << "%)";
std::cerr << std::flush;
}
}
};
vector<thread> threads;
bool doneTraining = false;
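// Launch the worker threads, giving each a contiguous slice of the shuffled
// indices.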
// Round up so integer truncation does not drop the trailing samples; clamp
// start so the last slices cannot run past the end of the index vector.
size_t numPerThread = (numSamples + numThreads - 1) / numThreads;
assert(numPerThread > 0);
for (size_t i = 0; i < (size_t)numThreads; i++) {
auto start = (std::min)(i * numPerThread, numSamples);
auto end = (std::min)(start + numPerThread, numSamples);
assert(end >= start);
assert(end <= numSamples);
auto b = indices.begin() + start;
auto e = indices.begin() + end;
assert(b >= indices.begin());
assert(e >= b);
assert(e <= indices.end());
threads.emplace_back(thread([=] {
trainThread(i, b, e);
}));
}
// .. and a norm truncation thread. It's not worth it to slow
// down every update with truncation, so just work our way through
// truncating as needed on a separate thread.
std::thread truncator([&] {
auto trunc = [](Matrix<Real>::Row row, double maxNorm) {
auto norm = norm2(row);
if (norm > maxNorm) {
row *= (maxNorm / norm);
}
};
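// Cycle over the LHS embedding rows until the workers finish; `doneTraining`
// is read without synchronization, in the same spirit as the racy rate
// updates above.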
for (int i = 0; !doneTraining; i++) {
auto wIdx = i % LHSEmbeddings_->numRows();
trunc(LHSEmbeddings_->row(wIdx), args_->norm);
}
});
for (auto& t: threads) t.join();
// All done. Shut the truncator down.
doneTraining = true;
truncator.join();
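// Combine the per-thread sums into a single mean loss per batch.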
Real totLoss = std::accumulate(losses.begin(), losses.end(), 0.0);
long totCount = std::accumulate(counts.begin(), counts.end(), 0L);
return totLoss / totCount;
}