Real EmbedModel::train()

in src/model.cpp [160:341]


Real EmbedModel::train(shared_ptr<InternDataHandler> data,
                       int numThreads,
                       std::chrono::time_point<std::chrono::high_resolution_clock> t_start,
                       int epochs_done,
                       Real rate,
                       Real finishRate,
                       bool verbose) {
  assert(rate >= finishRate);
  assert(rate >= 0.0);

  // Use a layer of indirection when accessing the corpus to allow shuffling.
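  // Shuffling a vector of ints is far cheaper than shuffling the examples
  // themselves, and it leaves the underlying corpus untouched.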
  auto numSamples = data->getSize();
  vector<int> indices(numSamples);
  std::iota(indices.begin(), indices.end(), 0); // needs <numeric>, as does accumulate below
  // std::random_shuffle was deprecated in C++14 and removed in C++17;
  // std::shuffle with an explicit engine (needs <random>) is the portable
  // replacement.
  std::shuffle(indices.begin(), indices.end(), std::mt19937(std::random_device{}()));

  // Compute word negatives
  if (args_->trainMode == 5 || args_->trainWord) {
    data->initWordNegatives();
  }

  // If we decrement after *every* sample, precision causes us to lose the
  // update.
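  // (For example, if Real is float, 0.01f - 1e-10f rounds back to 0.01f, so a
  // per-sample step of (rate - finishRate) / numSamples would silently vanish
  // on a large corpus.)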
  const int kDecrStep = 1000;
  // Guard against corpora smaller than kDecrStep, which would otherwise
  // divide by zero here.
  auto decrPerKSample = (rate - finishRate) / (std::max)(numSamples / kDecrStep, size_t(1));
  const Real negSearchLimit = (std::min)(numSamples,
                                         size_t(args_->negSearchLimit));

  numThreads = (std::max)(numThreads, 2);
  numThreads -= 1; // Withhold one thread for the norm thread.
  numThreads = (std::min)(numThreads, int(numSamples));
  vector<Real> losses(numThreads);
  vector<long> counts(numThreads);

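  // Each worker trains a contiguous slice of the shuffled indices and keeps
  // private loss/count accumulators, so the bookkeeping needs no locking.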
  auto trainThread = [&](int idx,
                         vector<int>::const_iterator start,
                         vector<int>::const_iterator end) {
    assert(start >= indices.begin());
    assert(end >= start);
    assert(end <= indices.end());
    bool amMaster = idx == 0;
    auto t_epoch_start = std::chrono::high_resolution_clock::now();
    losses[idx] = 0.0;
    counts[idx] = 0;

    unsigned int batch_sz = args_->batchSize;
    vector<ParseResults> examples;
    for (auto ip = start; ip < end; ip++) {
      auto i = *ip;
      float thisLoss = 0.0;
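      // Word-level pass: batch the word examples for this input and run one
      // update per full (or final partial) batch with the configured loss.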
      if (args_->trainMode == 5 || args_->trainWord) {
        vector<ParseResults> exs;
        data->getWordExamples(i, exs);
        vector<ParseResults> word_exs;
        for (unsigned int j = 0; j < exs.size(); j++) {
          word_exs.push_back(exs[j]);
          if (word_exs.size() >= batch_sz || j == exs.size() - 1) {
            if (args_->loss == "softmax") {
              thisLoss = trainNLLBatch(data, word_exs, negSearchLimit, rate, true);
            } else {
              thisLoss = trainOneBatch(data, word_exs, negSearchLimit, rate, true);
            }
            word_exs.clear();
            assert(thisLoss >= 0.0);
            counts[idx]++;
            losses[idx] += thisLoss;
          }
        }
      }
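      // Example-level pass (skipped in pure word-training mode 5): batch the
      // parsed examples, dropping any with an empty side.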
      if (args_->trainMode != 5) {
        ParseResults ex;
        data->getExampleById(i, ex);
        if (ex.LHSTokens.empty() || ex.RHSTokens.empty()) {
          continue;
        }
        examples.push_back(ex);
        if (examples.size() >= batch_sz || (ip + 1) == end) {
          if (args_->loss == "softmax") {
            thisLoss = trainNLLBatch(data, examples, negSearchLimit, rate, false);
          } else {
            thisLoss = trainOneBatch(data, examples, negSearchLimit, rate, false);
          }
          examples.clear();

          assert(thisLoss >= 0.0);
          counts[idx]++;
          losses[idx] += thisLoss;
        }
      }

      // Update the learning rate racily: every worker holds a reference to
      // the same rate variable, and an occasional lost decrement is accepted
      // in exchange for lock-free updates.
      if ((i % kDecrStep) == (kDecrStep - 1)) {
        rate -= decrPerKSample;
      }
      auto t_end = std::chrono::high_resolution_clock::now();
      auto tot_spent = std::chrono::duration<double>(t_end-t_start).count();
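      // Every worker checks the wall-clock budget independently and simply
      // stops once maxTrainTime is exceeded.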
      if (tot_spent > args_->maxTrainTime) {
        break;
      }
      if (amMaster && ((ip - indices.begin()) % 100 == 99 || (ip + 1) == end)) {
        auto t_epoch_spent =
          std::chrono::duration<double>(t_end-t_epoch_start).count();
        double ex_done_this_epoch = ip - indices.begin();
        int ex_left = ((end - start) * (args_->epoch - epochs_done))
                      - ex_done_this_epoch;
        double ex_done = epochs_done * (end - start) + ex_done_this_epoch;
        double time_per_ex = double(t_epoch_spent) / ex_done_this_epoch;
        int eta = int(time_per_ex * double(ex_left));
        double epoch_progress = ex_done_this_epoch / (end - start);
        double progress = ex_done / (ex_done + ex_left);
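        // Clamp the ETA to the remaining time budget; training aborts at
        // maxTrainTime regardless of how much work is left.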
        if (eta > args_->maxTrainTime - tot_spent) {
          eta = args_->maxTrainTime - tot_spent;
          progress = tot_spent / (eta + tot_spent);
        }
        int etah = eta / 3600;
        int etam = (eta - etah * 3600) / 60;
        int toth = int(tot_spent) / 3600;
        int totm = (tot_spent - toth * 3600) / 60;
        int tots = (tot_spent - toth * 3600 - totm * 60);
        std::cerr << std::fixed;
        std::cerr << "\rEpoch: " << std::setprecision(1) << 100 * epoch_progress << "%";
        std::cerr << "  lr: " << std::setprecision(6) << rate;
        std::cerr << "  loss: " << std::setprecision(6) << losses[idx] / counts[idx];
        if (eta < 60) {
          std::cerr << "  eta: <1min ";
        } else {
          std::cerr << "  eta: " << std::setprecision(3) << etah << "h" << etam << "m";
        }
        std::cerr << "  tot: " << std::setprecision(3) << toth << "h" << totm << "m"  << tots << "s ";
        std::cerr << " (" << std::setprecision(1) << 100 * progress << "%)";
        std::cerr << std::flush;
      }
    }
  };

  vector<thread> threads;
  // The truncator thread polls this flag, so it must be atomic (needs
  // <atomic>); with a plain bool the read in its loop is a data race, and the
  // compiler may legally hoist it so the thread never sees the store.
  std::atomic<bool> doneTraining(false);
  // ceil() is useless on numSamples / numThreads: the integer division has
  // already truncated, which can strand the last few samples with no thread.
  // Round up in integer arithmetic instead.
  size_t numPerThread = (numSamples + numThreads - 1) / numThreads;
  assert(numPerThread > 0);
  for (size_t i = 0; i < (size_t)numThreads; i++) {
    auto start = i * numPerThread;
    auto end = (std::min)(start + numPerThread, numSamples);
    assert(end >= start);
    assert(end <= numSamples);
    auto b = indices.begin() + start;
    auto e = indices.begin() + end;
    assert(b >= indices.begin());
    assert(e >= b);
    assert(e <= indices.end());
    threads.emplace_back([=] {
      trainThread(i, b, e);
    });
  }

  // .. and a norm truncation thread. It's not worth it to slow
  // down every update with truncation, so just work our way through
  // truncating as needed on a separate thread.
  std::thread truncator([&] {
    auto trunc = [](Matrix<Real>::Row row, double maxNorm) {
      auto norm = norm2(row);
      if (norm > maxNorm) {
        row *= (maxNorm / norm);
      }
    };
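    // Sweep the embedding rows round-robin until training ends, pulling any
    // row whose L2 norm exceeds args_->norm back onto the ball.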
    for (int i = 0; !doneTraining; i++) {
      auto wIdx = i % LHSEmbeddings_->numRows();
      trunc(LHSEmbeddings_->row(wIdx), args_->norm);
    }
  });
  for (auto& t: threads) t.join();
  // All done. Shut the truncator down.
  doneTraining = true;
  truncator.join();

  Real totLoss = std::accumulate(losses.begin(), losses.end(), 0.0);
  long totCount = std::accumulate(counts.begin(), counts.end(), 0L);
  // Avoid returning NaN (0.0 / 0) if the time budget expired before any
  // batch completed.
  return totCount > 0 ? totLoss / totCount : 0.0;
}
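
// A minimal sketch of a caller-side epoch loop for train(). This is hedged:
// model, data, and the exact args_ fields referenced are assumptions about
// code outside this excerpt, which is not shown here.
//
//   auto t_start = std::chrono::high_resolution_clock::now();
//   Real lr = args_->lr;
//   const Real decrPerEpoch = lr / args_->epoch;
//   for (int epoch = 0; epoch < args_->epoch; epoch++) {
//     model->train(data, args_->thread, t_start, epoch,
//                  lr, lr - decrPerEpoch, args_->verbose);
//     lr -= decrPerEpoch;
//   }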