in src/model.cpp [563:683]
float EmbedModel::trainNLLBatch(
    shared_ptr<InternDataHandler> data,
    const vector<ParseResults>& batch_exs,
    int32_t negSearchLimit,
    Real rate0,
    bool trainWord) {
  auto batch_sz = batch_exs.size();
  std::vector<Matrix<Real>> lhs(batch_sz), rhsP(batch_sz), rhsN(negSearchLimit);
  using namespace boost::numeric::ublas;
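  // Project each example's input tokens (LHS) and positive label (RHS)
  // into the embedding space.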
  for (size_t i = 0; i < batch_sz; i++) {
    const auto& items = batch_exs[i].LHSTokens;
    const auto& labels = batch_exs[i].RHSTokens;
    projectLHS(items, lhs[i]);
    check(lhs[i]);
    projectRHS(labels, rhsP[i]);
    check(rhsP[i]);
  }
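  // Per-example scratch: softmax probabilities, LHS gradients, per-negative
  // and per-positive learning rates, and the count of usable negatives.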
  std::vector<std::vector<Real>> prob(batch_sz);
  std::vector<std::vector<Base>> batch_negLabels;
  std::vector<Matrix<Real>> gradW(batch_sz);
  std::vector<Real> loss(batch_sz);
  std::vector<std::vector<Real>> nRate(batch_sz);
  std::vector<int> num_negs(batch_sz, 0);
  std::vector<Real> labelRate(batch_sz);
  Real total_loss = 0.0;
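  // Sample negSearchLimit negative labels once; they are shared by every
  // example in the batch.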
  for (int i = 0; i < negSearchLimit; i++) {
    std::vector<Base> negLabels;
    if (trainWord) {
      data->getRandomWord(negLabels);
    } else {
      data->getRandomRHS(negLabels);
    }
    projectRHS(negLabels, rhsN[i]);
    check(rhsN[i]);
    batch_negLabels.push_back(negLabels);
  }
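  // Score each example against its positive and every negative, skipping
  // negatives that collide with the true label, and track the running
  // maximum score for a numerically stable softmax.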
  for (size_t i = 0; i < batch_sz; i++) {
    nRate[i].resize(negSearchLimit);
    std::vector<int> index;
    int cls_cnt = 1;
    prob[i].push_back(dot(lhs[i], rhsP[i]));
    Real max = prob[i][0];
    for (int j = 0; j < negSearchLimit; j++) {
      nRate[i][j] = 0.0;
      if (batch_negLabels[j] == batch_exs[i].RHSTokens) {
        continue;
      }
      prob[i].push_back(dot(lhs[i], rhsN[j]));
      max = (std::max)(max, prob[i][cls_cnt]);
      index.push_back(j);
      cls_cnt += 1;
    }
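    // Softmax over [positive, kept negatives]; subtracting max before exp()
    // avoids overflow without changing the normalized probabilities.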
    loss[i] = 0.0;
    // Skip this example: every sampled negative matched its positive label.
    if (cls_cnt == 1) {
      continue;
    }
    num_negs[i] = cls_cnt - 1;
    Real base = 0;
    for (int j = 0; j < cls_cnt; j++) {
      prob[i][j] = exp(prob[i][j] - max);
      base += prob[i][j];
    }
    // normalize probabilities
    for (int j = 0; j < cls_cnt; j++) {
      prob[i][j] /= base;
    }
    loss[i] = -log(prob[i][0]);
    total_loss += loss[i];
    // Let w be the average of the words in the post, t+ be the
    // positive example (the tag the post has) and t- range over the
    // negative examples sampled above.
    // Our error E is:
    //
    //    E = -log P(t+)
    //
    // where P(t) = exp(dot(w, t)) / (\sum_{t'} exp(dot(w, t')))
    //
    // Differentiating term-by-term we get:
    //
    //    dE / dt+ = w (P(t+) - 1)
    //    dE / dt- = w P(t-)
    //    dE / dw  = (P(t+) - 1) t+ + \sum_{t-} P(t-) t-
    gradW[i] = rhsP[i];
    gradW[i].matrix *= (prob[i][0] - 1);
    for (int j = 1; j < cls_cnt; j++) {
      auto inj = index[j - 1];
      gradW[i].add(rhsN[inj], prob[i][j]);
      nRate[i][inj] = prob[i][j] * rate0;
    }
    labelRate[i] = (prob[i][0] - 1) * rate0;
  }
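  // Apply the accumulated LHS gradients; labelRate and nRate fold each
  // class's softmax probability into its effective update step.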
  backward(
      batch_exs, batch_negLabels,
      gradW, lhs, num_negs,
      rate0, labelRate, nRate);
  return total_loss;
}
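
// ---------------------------------------------------------------------------
// Illustration only (not part of model.cpp): a minimal, self-contained sketch
// of the stable-softmax NLL math used above, on plain std::vector<float>.
// Index 0 holds the positive score; the rest are negative scores. The helper
// name `stableNLL` is hypothetical; StarSpace itself works on Matrix<Real>
// embeddings rather than raw score vectors.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cmath>
#include <vector>

// Returns E = -log P(0) and fills grad[j] = dE/dscore_j:
// P(0) - 1 for the positive (j == 0), P(j) for each negative.
// Assumes scores is non-empty.
inline float stableNLL(const std::vector<float>& scores,
                       std::vector<float>& grad) {
  float mx = *std::max_element(scores.begin(), scores.end());
  grad.assign(scores.size(), 0.0f);
  float base = 0.0f;
  for (size_t j = 0; j < scores.size(); j++) {
    grad[j] = std::exp(scores[j] - mx);  // max shift prevents overflow
    base += grad[j];
  }
  for (size_t j = 0; j < scores.size(); j++) {
    grad[j] /= base;  // grad[j] now holds P(j)
  }
  float loss = -std::log(grad[0]);
  grad[0] -= 1.0f;  // dE/dscore_0 = P(0) - 1
  return loss;
}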