float EmbedModel::trainNLLBatch()

in src/model.cpp [563:683]


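// Trains one mini-batch with a softmax negative log-likelihood loss:
// negatives are sampled once and shared across the batch, each example is
// scored against its positive label and the sampled negatives, and the
// resulting gradients are handed to backward() for the embedding update.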
float EmbedModel::trainNLLBatch(
    shared_ptr<InternDataHandler> data,
    const vector<ParseResults>& batch_exs,
    int32_t negSearchLimit,
    Real rate0,
    bool trainWord) {

  auto batch_sz = batch_exs.size();
  std::vector<Matrix<Real>> lhs(batch_sz), rhsP(batch_sz), rhsN(negSearchLimit);

  using namespace boost::numeric::ublas;

  for (int i = 0; i < batch_sz; i++) {
    const auto& items = batch_exs[i].LHSTokens;
    const auto& labels = batch_exs[i].RHSTokens;
    projectLHS(items, lhs[i]);
    check(lhs[i]);

    projectRHS(labels, rhsP[i]);
    check(rhsP[i]);
  }

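  // Buffers: per-example softmax probabilities, gradients w.r.t. the LHS
  // projection, losses, per-negative and positive-label learning rates,
  // negative counts, and the sampled negative labels shared by the batch.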
  std::vector<std::vector<Real>> prob(batch_sz);
  std::vector<std::vector<Base>> batch_negLabels;
  std::vector<Matrix<Real>> gradW(batch_sz);
  std::vector<Real> loss(batch_sz);

  std::vector<std::vector<Real>> nRate(batch_sz);
  std::vector<int> num_negs(batch_sz, 0);
  std::vector<Real> labelRate(batch_sz);

  Real total_loss = 0.0;

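  // Sample negSearchLimit negative candidates once for the whole batch and
  // project each into the embedding space, mirroring the positive RHS.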
  for (int i = 0; i < negSearchLimit; i++) {
    std::vector<Base> negLabels;
    if (trainWord) {
      data->getRandomWord(negLabels);
    } else {
      data->getRandomRHS(negLabels);
    }
    projectRHS(negLabels, rhsN[i]);
    check(rhsN[i]);
    batch_negLabels.push_back(negLabels);
  }

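  // Score each example against its positive label and the shared negatives,
  // then turn the scores into a softmax distribution and its NLL loss.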
  for (int i = 0; i < batch_sz; i++) {
    nRate[i].resize(negSearchLimit);
    std::vector<int> index;

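    // prob[i][0] holds the positive score; negatives that happen to equal the
    // true label are skipped so they do not act as false negatives.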
    int cls_cnt = 1;
    prob[i].clear();
    prob[i].push_back(dot(lhs[i], rhsP[i]));
    Real max = prob[i][0];

    for (int j = 0; j < negSearchLimit; j++) {
      nRate[i][j] = 0.0;
      if (batch_negLabels[j] == batch_exs[i].RHSTokens) {
        continue;
      }
      prob[i].push_back(dot(lhs[i], rhsN[j]));
      max = (std::max)(max, prob[i][cls_cnt]);
      index.push_back(j);
      cls_cnt += 1;
    }
    loss[i] = 0.0;

    // skip, failed to find any negatives
    if (cls_cnt == 1) {
      continue;
    }

    num_negs[i] = cls_cnt - 1;
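    // Softmax over the collected scores, subtracting the max for stability.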
    Real base = 0;
    for (int j = 0; j < cls_cnt; j++) {
      prob[i][j] = exp(prob[i][j] - max);
      base += prob[i][j];
    }

    // normalize probabilities
    for (int j = 0; j < cls_cnt; j++) {
      prob[i][j] /= base;
    }

    loss[i] = -log(prob[i][0]);
    total_loss += loss[i];

    // Let w be the average of the words in the post, t+ be the
    // positive example (the tag the post has) and t- range over the
    // negative examples (the randomly sampled labels above).
    // Our error E is:
    //
    //    E = - log P(t+)
    //
    // Where P(t) = exp(dot(w, t)) / (\sum_{t'} exp(dot(w, t')))
    //
    // Differentiating term-by-term we get:
    //
    //    dE / dw = t+ (P(t+) - 1) + \sum_{t-} t- P(t-)
    //    dE / dt+ = w (P(t+) - 1)
    //    dE / dt- = w P(t-)

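    // Accumulate dE/dw: the positive label scaled by (P(t+) - 1), plus each
    // surviving negative scaled by its probability. The same factors, scaled
    // by rate0, become the per-label learning rates used in backward().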
    gradW[i] = rhsP[i];
    gradW[i].matrix *= (prob[i][0] - 1);

    for (int j = 1; j < cls_cnt; j++) {
      auto inj = index[j - 1];
      gradW[i].add(rhsN[inj], prob[i][j]);
      nRate[i][inj] = prob[i][j] * rate0;
    }
    labelRate[i] = (prob[i][0] - 1) * rate0;
  }

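  // Hand the per-example gradients and learning rates to backward(), which
  // applies the updates to the LHS and RHS embeddings.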
  backward(
      batch_exs, batch_negLabels,
      gradW, lhs, num_negs,
      rate0, labelRate, nRate);

  return total_loss;
}
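
The gradient bookkeeping above is easier to follow on plain vectors. The sketch below is a minimal, self-contained illustration, not part of model.cpp: the names w, tPos, tNegs and the plain std::vector types are assumptions made for the example. It reproduces the same softmax-NLL math as the derivative comment: a max-subtracted softmax over one positive and a few negative scores, E = -log P(t+), and the gradients dE/dw, dE/dt+, dE/dt-.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative only: plain-vector version of the softmax-NLL gradients.
using Vec = std::vector<float>;

float dot(const Vec& a, const Vec& b) {
  float s = 0.0f;
  for (size_t k = 0; k < a.size(); k++) s += a[k] * b[k];
  return s;
}

int main() {
  Vec w    = {0.3f, -0.1f, 0.8f};   // averaged LHS embedding (hypothetical)
  Vec tPos = {0.5f,  0.2f, 0.4f};   // positive label embedding
  std::vector<Vec> tNegs = {{-0.2f, 0.7f, 0.1f}, {0.1f, -0.3f, 0.6f}};

  // Scores: slot 0 is the positive, the rest are negatives.
  std::vector<float> score = {dot(w, tPos)};
  for (const auto& tn : tNegs) score.push_back(dot(w, tn));

  // Softmax with the max subtracted for numerical stability.
  float mx = *std::max_element(score.begin(), score.end());
  float base = 0.0f;
  std::vector<float> prob(score.size());
  for (size_t j = 0; j < score.size(); j++) {
    prob[j] = std::exp(score[j] - mx);
    base += prob[j];
  }
  for (auto& p : prob) p /= base;

  float loss = -std::log(prob[0]);  // E = -log P(t+)

  // dE/dw = tPos * (P(t+) - 1) + sum_j tNegs[j] * P(t-_j)
  Vec gradW(w.size(), 0.0f);
  for (size_t k = 0; k < w.size(); k++) gradW[k] = tPos[k] * (prob[0] - 1.0f);
  for (size_t j = 0; j < tNegs.size(); j++)
    for (size_t k = 0; k < w.size(); k++) gradW[k] += tNegs[j][k] * prob[j + 1];

  // dE/dt+ = w * (P(t+) - 1) and dE/dt-_j = w * P(t-_j): scaled copies of w,
  // which is why the batch code only stores the scale factors.
  float posScale = prob[0] - 1.0f;

  std::printf("loss = %f, dE/dt+ scale = %f\n", loss, posScale);
  std::printf("gradW = (%f, %f, %f)\n", gradW[0], gradW[1], gradW[2]);
  return 0;
}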