in src/model.cpp [487:561]
void EmbedModel::backward(
const vector<ParseResults>& batch_exs,
const vector<vector<Base>>& batch_negLabels,
vector<Matrix<Real>> gradW,
vector<Matrix<Real>> lhs,
const vector<int>& num_negs,
Real rate_lhs,
const vector<Real>& rate_rhsP,
const vector<vector<Real>>& nRate) {
using namespace boost::numeric::ublas;
auto cols = args_->dim;
typedef
std::function<void(MatrixRow&, const MatrixRow&, Real, Real, std::vector<Real>&, int32_t)>
UpdateFn;
std::function<void(MatrixRow&, const MatrixRow&, Real, Real, std::vector<Real>&, int32_t)> updatePlain =
[&] (MatrixRow& dest,
const MatrixRow& src,
Real rate,
Real weight,
std::vector<Real>& adagradWeight,
int32_t idx) {
dest -= (rate * src);
};
std::function<void(MatrixRow&, const MatrixRow&, Real, Real, std::vector<Real>&, int32_t)> updateAdagrad =
[&] (MatrixRow& dest,
const MatrixRow& src,
Real rate,
Real weight,
std::vector<Real>& adagradWeight,
int32_t idx) {
assert(idx < adagradWeight.size());
adagradWeight[idx] += weight / cols;
rate /= sqrt(adagradWeight[idx] + 1e-6);
updatePlain(dest, src, rate, weight, adagradWeight, idx);
};
UpdateFn* update = args_->adagrad ?
(UpdateFn*)(&updateAdagrad) : (UpdateFn*)(&updatePlain);
auto batch_sz = batch_exs.size();
std::vector<Real> n1(batch_sz, 0.0);
std::vector<Real> n2(batch_sz, 0.0);
if (args_->adagrad) {
for (unsigned int i = 0; i < batch_sz; i++) if (num_negs[i] > 0) {
n1[i] = dot(gradW[i], gradW[i]);
n2[i] = dot(lhs[i], lhs[i]);
}
}
// Update input items.
// Update positive example.
for (unsigned int i = 0; i < batch_sz; i++) if (num_negs[i] > 0) {
const auto& items = batch_exs[i].LHSTokens;
const auto& labels = batch_exs[i].RHSTokens;
for (auto w : items) {
auto row = LHSEmbeddings_->row(index(w));
(*update)(row, gradW[i], rate_lhs * weight(w), n1[i], LHSUpdates_, index(w));
}
for (auto la : labels) {
auto row = RHSEmbeddings_->row(index(la));
(*update)(row, lhs[i], rate_rhsP[i] * weight(la), n2[i], RHSUpdates_, index(la));
}
}
// Update negative example
for (unsigned int j = 0; j < batch_negLabels.size(); j++) {
for (unsigned int i = 0; i < batch_sz; i++) if (fabs(nRate[i][j]) > 1e-8) {
for (auto la : batch_negLabels[j]) {
auto row = RHSEmbeddings_->row(index(la));
(*update)(row, lhs[i], nRate[i][j] * weight(la), n2[i], RHSUpdates_, index(la));
}
}
}
}