void EmbedModel::backward()

in src/model.cpp [487:561]


void EmbedModel::backward(
    const vector<ParseResults>& batch_exs,
    const vector<vector<Base>>& batch_negLabels,
    vector<Matrix<Real>> gradW,
    vector<Matrix<Real>> lhs,
    const vector<int>& num_negs,
    Real rate_lhs,
    const vector<Real>& rate_rhsP,
    const vector<vector<Real>>& nRate) {

  using namespace boost::numeric::ublas;
  auto cols = args_->dim;

  typedef
    std::function<void(MatrixRow&, const MatrixRow&, Real, Real, std::vector<Real>&, int32_t)>
    UpdateFn;
  std::function<void(MatrixRow&, const MatrixRow&, Real, Real, std::vector<Real>&, int32_t)> updatePlain =
    [&] (MatrixRow& dest,
         const MatrixRow& src,
         Real rate,
         Real weight,
         std::vector<Real>& adagradWeight,
         int32_t idx) {
    dest -= (rate * src);
  };
  std::function<void(MatrixRow&, const MatrixRow&, Real, Real, std::vector<Real>&, int32_t)> updateAdagrad =
    [&] (MatrixRow& dest,
         const MatrixRow& src,
         Real rate,
         Real weight,
         std::vector<Real>& adagradWeight,
         int32_t idx) {
    assert(idx < adagradWeight.size());
    adagradWeight[idx] += weight / cols;
    rate /= sqrt(adagradWeight[idx] + 1e-6);
    updatePlain(dest, src, rate, weight, adagradWeight, idx);
  };

  UpdateFn* update = args_->adagrad ?
    (UpdateFn*)(&updateAdagrad) : (UpdateFn*)(&updatePlain);

  auto batch_sz = batch_exs.size();
  std::vector<Real> n1(batch_sz, 0.0);
  std::vector<Real> n2(batch_sz, 0.0);
  if (args_->adagrad) {
    for (unsigned int i = 0; i < batch_sz; i++) if (num_negs[i] > 0) {
      n1[i] = dot(gradW[i], gradW[i]);
      n2[i] = dot(lhs[i], lhs[i]);
    }
  }
  // Update input items.
  // Update positive example.
  for (unsigned int i = 0; i < batch_sz; i++) if (num_negs[i] > 0) {
    const auto& items = batch_exs[i].LHSTokens;
    const auto& labels = batch_exs[i].RHSTokens;
    for (auto w : items) {
      auto row = LHSEmbeddings_->row(index(w));
      (*update)(row, gradW[i], rate_lhs * weight(w), n1[i], LHSUpdates_, index(w));
    }
    for (auto la : labels) {
      auto row = RHSEmbeddings_->row(index(la));
      (*update)(row, lhs[i], rate_rhsP[i] * weight(la), n2[i], RHSUpdates_, index(la));
    }
  }

  // Update negative example
  for (unsigned int j = 0; j < batch_negLabels.size(); j++) {
    for (unsigned int i = 0; i < batch_sz; i++) if (fabs(nRate[i][j]) > 1e-8) {
      for (auto la : batch_negLabels[j]) {
        auto row = RHSEmbeddings_->row(index(la));
        (*update)(row, lhs[i], nRate[i][j] * weight(la), n2[i], RHSUpdates_, index(la));
      }
    }
  }
}