in Source/Readers/Kaldi2Reader/msra_mgram.h [2379:2807]
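// Estimate an m-gram LM with modified Kneser-Ney smoothing from the counts
// accumulated so far.
//  - startId: word id of the sentence-start token <s>; it gets special
//    treatment below (never counted as a predicted event, resurrected at the
//    end with a log score of -99).
//  - minObs: per-order minimum counts; minObs[m-1] applies to m-grams, and
//    minObs[0] additionally acts as a vocabulary cut-off for unigrams.
//  - dropWord: per-word-id flags of words to drop (passed by value and
//    extended below with words whose unigram count falls under the cut-off);
//    all m-grams containing a dropped word are pruned.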
void estimate(int startId, const std::vector<unsigned int> &minObs, vector<bool> dropWord)
{
if (!adaptBuffer.empty())
throw runtime_error("estimate: adaptation buffer not empty, call adapt(*,0) to flush buffer first");
// Establish w->id mapping -- mapping is identical (w=id) during estimation.
std::vector<int> w2id(map.maxid() + 1);
foreach_index (i, w2id)
w2id[i] = i;
// std::vector<int> w2id (map.identical_map());
// close down creation of new tokens, so we can random-access
map.created(w2id);
// ensure M reflects the actual order of read data
while (M > 0 && counts.size(M) == 0)
resize(M - 1);
for (int m = 1; m <= M; m++)
fprintf(stderr, "estimate: read %d %d-grams\n", counts.size(m), m);
// === Kneser-Ney smoothing
// This is a strange algorithm.
#if 1 // Kneser-Ney back-off
// It seems not to help fourgram models when applied at the trigram level.
// But if it is applied only to bigrams and unigrams, there is a gain
// from the fourgram. So we are not applying it to trigrams and above.
// ... TODO: use a constant to define the maximum KN count level,
// and then do not allocate memory above that.
mgram_data<unsigned int> KNCounts; // [shifted m-gram] (*,v,w)
mgram_data<unsigned int> KNTotalCounts; // [shifted, shortened m-gram] (*,v,*)
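// Kneser-Ney counts replace raw counts in the lower-order distributions
// (applied below only for orders m < 3 and m < M): KNCounts[(v,w)] is the
// number of distinct contexts u such that (u,v,w) was observed, i.e.
// N1+(*,v,w), and KNTotalCounts[(v)] is the corresponding total N1+(*,v,*).
// The lower-order probability is then estimated from KNCounts / KNTotalCounts
// instead of from raw count / history count.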
if (M >= 2)
{
fprintf(stderr, "estimate: allocating Kneser-Ney counts...\n");
KNCounts.init(M - 1);
for (int m = 0; m <= M - 1; m++)
KNCounts.assign(m, counts.size(m), 0);
KNTotalCounts.init(M - 2);
for (int m = 0; m <= M - 2; m++)
KNTotalCounts.assign(m, counts.size(m), 0);
fprintf(stderr, "estimate: computing Kneser-Ney counts...\n");
// loop over all m-grams to determine KN counts
for (mgram_map::deep_iterator iter(map); iter; ++iter)
{
const mgram_map::key key = *iter;
if (key.order() < 2)
continue; // undefined for unigrams
const mgram_map::key key_w = key.pop_h();
const mgram_map::foundcoord c_w = map[key_w];
if (!c_w.valid_w())
throw runtime_error("estimate: invalid shortened KN m-gram");
KNCounts[c_w]++; // (u,v,w) -> count (*,v,w)
const mgram_map::key key_h = key_w.pop_w();
mgram_map::foundcoord c_h = map[key_h];
if (!c_h.valid_w())
throw runtime_error("estimate: invalid shortened KN history");
KNTotalCounts[c_h]++; // (u,v,w) -> count (*,v,*)
}
}
#else // regular back-off: just use regular counts instead
mgram_data<unsigned int> &KNCounts = counts;
mgram_data<unsigned int> &KNTotalCounts = counts;
// not 'const' so we can later clear() them... this is only for testing anyway
#endif
// === estimate "modified Kneser-Ney" discounting values
// after Chen and Goodman: An empirical study of smoothing techniques for
// language modeling, Harvard TR-10-98 -- a rich resource about everything LM!
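// The discounts are derived per order m from the counts-of-counts n1..n4
// (how many m-grams occur exactly 1, 2, 3, 4 times):
//   Y     = n1 / (n1 + 2 n2)
//   d1[m] = 1 - 2 Y n2 / n1     (applied to m-grams seen once)
//   d2[m] = 2 - 3 Y n3 / n2     (applied to m-grams seen twice)
//   d3[m] = 3 - 4 Y n4 / n3     (applied to m-grams seen 3 or more times)
// E.g. with made-up values n1=100, n2=40, n3=20, n4=12 this gives
// Y=0.556, d1=0.556, d2=1.167, d3=1.667.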
std::vector<double> d1(M + 1, 0.0);
std::vector<double> d2(M + 1, 0.0);
std::vector<double> d3(M + 1, 0.0);
fprintf(stderr, "estimate: discounting values:");
{
// actually estimate discounting values
std::vector<int> n1(M + 1, 0); // how many have count=1, 2, 3, 4
std::vector<int> n2(M + 1, 0);
std::vector<int> n3(M + 1, 0);
std::vector<int> n4(M + 1, 0);
for (mgram_map::deep_iterator iter(map); iter; ++iter)
{
int m = iter.order();
if (m == 0)
continue; // skip the zerogram
unsigned int count = counts[iter];
// Kneser-Ney smoothing can also be done for back-off weight computation
if (m < M && m < 3) // for comments see where we estimate the discounted probabilities
{ // ^^ seems not to work for 4-grams...
const mgram_map::key key = *iter; // needed to check for startId
assert(key.order() == m);
if (m < 2 || key.pop_w().back() != startId)
{
count = KNCounts[iter];
if (count == 0) // must exist
throw runtime_error("estimate: malformed data: back-off value not found (numerator)");
}
}
if (count == 1)
n1[m]++;
else if (count == 2)
n2[m]++;
else if (count == 3)
n3[m]++;
else if (count == 4)
n4[m]++;
}
for (int m = 1; m <= M; m++)
{
if (n1[m] == 0)
throw runtime_error(msra::strfun::strprintf("estimate: error estimating discounting values: n1[%d] == 0", m));
if (n2[m] == 0)
throw runtime_error(msra::strfun::strprintf("estimate: error estimating discounting values: n2[%d] == 0", m));
// if (n3[m] == 0) RuntimeError ("estimate: error estimating discounting values: n3[%d] == 0", m);
double Y = n1[m] / (n1[m] + 2.0 * n2[m]);
if (n3[m] == 0 || n4[m] == 0)
{
fprintf(stderr, "estimate: n3[%d] or n4[%d] is 0, falling back to unmodified discounting\n", m, m);
d1[m] = Y;
d2[m] = Y;
d3[m] = Y;
}
else
{
d1[m] = 1.0 - 2.0 * Y * n2[m] / n1[m];
d2[m] = 2.0 - 3.0 * Y * n3[m] / n2[m];
d3[m] = 3.0 - 4.0 * Y * n4[m] / n3[m];
}
// ... can these be negative??
fprintf(stderr, " (%.3f, %.3f, %.3f)", d1[m], d2[m], d3[m]);
}
fprintf(stderr, "\n");
}
// === threshold against minimum counts (set counts to 0)
// this is done to save memory, but it has no impact on the seen probabilities
// ...well, it does, as pruned mass gets pushed to the back-off distribution... ugh!
fprintf(stderr, "estimate: pruning against minimum counts...\n");
// prune unigrams first (unigram cut-off can be higher than m-gram cut-offs,
// as a means to decimate the vocabulary)
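// E.g. with a hypothetical minObs = { 5, 2, 2 } for a trigram model, words
// seen fewer than 5 times are dropped from the vocabulary, and bigrams and
// trigrams seen only once are pruned.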
unsigned int minUniObs = minObs[0]; // minimum unigram count
int removedWords = 0;
for (mgram_map::iterator iter(map, 1); iter; ++iter)
{ // unigram pruning is special: may be higher than m-gram threshold
if (counts[iter] >= minUniObs)
continue;
int wid = *iter;
dropWord[wid] = true; // will throw out all related m-grams
removedWords++;
}
fprintf(stderr, "estimate: removing %d too rare vocabulary entries\n", removedWords);
// now prune m-grams against count cut-off
std::vector<int> numMGrams(M + 1, 0);
msra::basetypes::fixed_vector<mgram_map::coord> histCoord(M); // index of history mgram
for (int m = 1; m <= M; m++)
{
for (mgram_map::deep_iterator iter(map, m); iter; ++iter)
{
if (iter.order() != m)
{
// a parent level: remember it so the consistency check below can
// find the current m-gram's history (entries are nested like a tree)
histCoord[iter.order()] = iter;
continue;
}
bool prune = counts[iter] < minObs[m - 1]; // prune if count below minimum
// prune by vocabulary
const mgram_map::key key = *iter;
for (int k = 0; !prune && k < m; k++)
{
int wid = key[k];
prune |= dropWord[wid];
}
if (prune)
{
counts[iter] = 0; // pruned: this is how we remember it
continue;
}
// for remaining m-grams, check whether the structure is still intact,
// i.e. that the parent (m-1)-gram has not been pruned away
mgram_map::coord j = histCoord[m - 1]; // parent
if (counts[j] == 0)
RuntimeError("estimate: invalid pruning: a parent m-gram got pruned away");
numMGrams[m]++;
}
}
for (int m = 1; m <= M; m++)
{
fprintf(stderr, "estimate: %d-grams after pruning: %d out of %d (%.1f%%)\n", m,
numMGrams[m], counts.size(m),
100.0 * numMGrams[m] / max(counts.size(m), 1));
}
// ensure M reflects the actual order of read data after pruning
while (M > 0 && numMGrams[M] == 0)
resize(M - 1); // will change M
// === compact memory after pruning
// naw... this is VERY tricky with the mgram_map architecture, as all data would need to be kept in sync
// So for now we just skip the pruned entries in all subsequent steps (i.e. we don't actually save memory)
// === estimate M-gram
fprintf(stderr, "estimate: estimating probabilities...\n");
// dimension the m-gram store
mgram_data<float> P(M); // [M+1][i] probabilities
for (int m = 1; m <= M; m++)
P.reserve(m, numMGrams[m]);
// compute discounted probabilities (uninterpolated except, later, for unigram)
// We estimate into a new map so that pruned items get removed.
// For large data sets, where strong pruning is used, there is a net
// memory gain from doing this (we gain if pruning cuts more than half).
mgram_map Pmap(M);
for (int m = 1; m <= M; m++)
Pmap.reserve(m, numMGrams[m]);
mgram_map::cache_t PmapCache; // used in map.create()
// m-grams
P.push_back(mgram_map::coord(), 0.0f); // will be updated later
for (int m = 1; m <= M; m++)
{
fprintf(stderr, "estimate: estimating %d %d-gram probabilities...\n", numMGrams[m], m);
// loop over all m-grams of level 'm'
msra::basetypes::fixed_vector<mgram_map::coord> histCoord(m);
for (mgram_map::deep_iterator iter(map, m); iter; ++iter)
{
if (iter.order() != m)
{
// a parent: remember how successors can find their history
// (entries are nested like a tree)
histCoord[iter.order()] = iter;
continue;
}
const mgram_map::key key = *iter;
assert(key.order() == iter.order()); // (remove this check once verified)
// get history's count
const mgram_map::coord j = histCoord[m - 1]; // index of parent entry
double histCount = counts[j]; // parent count --before pruning
// double histCount = succCount[j]; // parent count --actuals after pruning
// estimate probability for this M-gram
unsigned int count = counts[iter];
// this is 0 for pruned entries
// count = numerator, histCount = denominator
// Kneser-Ney smoothing --replace all but the highest-order
// distribution with that strange Kneser-Ney smoothed distribution.
if (m < M && m < 3 && count > 0) // all non-pruned items except highest order
{ // ^^ seems not to work for 4-gram
// We use a normal distribution if the history is the sentence
// start, as there we fall back without back-off. [Thanks to
// Yining Chen for the tip.]
if (m < 2 || key.pop_w().back() != startId)
{
count = KNCounts[iter]; // (u,v,w) -> count (*,v,w)
if (count == 0) // must exist
RuntimeError("estimate: malformed data: back-off value not found (numerator)");
const mgram_map::key key_h = key.pop_w();
mgram_map::foundcoord c_h = map[key_h];
if (!c_h.valid_w())
throw runtime_error("estimate: invalid shortened KN history");
histCount = KNTotalCounts[c_h]; // (u,v,w) -> count (*,v,*)
if (histCount == 0) // must exist
RuntimeError("estimate: malformed data: back-off value not found (denominator)");
assert(histCount >= count);
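// at this point, count / histCount would be the (undiscounted) Kneser-Ney
// continuation probability N1+(*,v,w) / N1+(*,v,*); discounting is applied below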
}
}
// pruned case
if (count == 0) // this entry was pruned before
goto skippruned;
// <s> does not count as an event, as it is never emitted.
// For now we prune it, but later we put the unigram back with -99.0.
if (key.back() == startId)
{ // (u, v, <s>)
if (m > 1) // do not generate m-grams
goto skippruned;
count = 0; // unigram is kept in structure
}
else if (m == 1)
{ // unigram non-<s> events
histCount--; // do not count <s> in denominator either
// For non-unigrams, we don't need to care because m-gram
// histories of <s> always end in </s>, and we never ask for such an m-gram
// ... TODO: actually, is subtracting 1 the right thing to do here?
// shouldn't we subtract the unigram count of <s> instead?
}
// Histories with any token before <s> are not valuable, and
// actually need to be removed for consistency with the above
// rule of removing m-grams predicting <s> (if we don't we may
// create orphan m-grams).
for (int k = 1; k < m - 1; k++)
{ // ^^ <s> at k=0 and k=m-1 is OK; anywhere else -> useless m-gram
if (key[k] == startId)
goto skippruned;
}
// estimate discounted probability
double dcount = count; // "modified Kneser-Ney" discounting
if (count >= 3)
dcount -= d3[m];
else if (count == 2)
dcount -= d2[m];
else if (count == 1)
dcount -= d1[m];
if (dcount < 0.0) // 0.0 itself is caused by <s>
throw runtime_error("estimate: negative discounted count value");
if (histCount == 0)
RuntimeError("estimate: unexpected 0 denominator");
double dP = dcount / histCount;
// and this is the discounted probability value
{
// Actually, 'key' uses "mapped" word ids, while create()
// expects unmapped ones. However, we have established an
// identical mapping at the start of this function, such that
// we can be sure that key=unmapped key.
mgram_map::coord c = Pmap.create((mgram_map::unmapped_key) key, PmapCache);
P.push_back(c, (float) dP);
}
skippruned:; // m-gram was pruned
}
}
// the distributions are not normalized --discount mass is missing
fprintf(stderr, "estimate: freeing memory for counts...\n");
KNCounts.clear(); // free some memory
KNTotalCounts.clear();
// the only items used below are P and Pmap.
w2id.resize(Pmap.maxid() + 1);
foreach_index (i, w2id)
w2id[i] = i;
// std::vector<int> w2id (Pmap.identical_map());
Pmap.created(w2id); // finalize and establish mapping for read access
map.swap(Pmap); // install the new map in our m-gram
Pmap.clear(); // no longer using the old one
counts.clear(); // counts also no longer needed
// zerogram
int vocabSize = 0;
for (mgram_map::iterator iter(map, 1); iter; ++iter)
if (P[iter] > 0.0) // (note: this excludes <s> and all pruned items)
vocabSize++;
P[mgram_map::coord()] = (float) (1.0 / vocabSize); // zerogram probability
// interpolating the unigram with the zerogram
// This is necessary as there is no back-off path from the unigram
// except in the OOV case. I.e. probability mass that was discounted
// from the unigrams is lost. We fix it by using linear interpolation
// instead of strict discounting for the unigram distribution.
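// I.e. each kept unigram receives an extra missingUnigramMass * P(zerogram)
// = missingUnigramMass / vocabSize, so the unigram distribution sums to 1 again.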
double unigramSum = 0.0;
for (mgram_map::iterator iter(map, 1); iter; ++iter)
unigramSum += P[iter];
double missingUnigramMass = 1.0 - unigramSum;
if (missingUnigramMass > 0.0)
{
float missingUnigramProb = (float) (missingUnigramMass * P[mgram_map::coord()]);
fprintf(stderr, "estimate: distributing missing unigram mass of %.2f to %d unigrams\n",
missingUnigramMass, vocabSize);
for (mgram_map::iterator iter(map, 1); iter; ++iter)
{
if (P[iter] == 0.0f)
continue; // pruned
P[iter] += missingUnigramProb; // add it in
}
}
// --- M-gram sections --back-off weights
fprintf(stderr, "estimate: determining back-off weights...\n");
computeBackoff(map, M, P, logB, false);
// now the LM is normalized assuming the ARPA back-off computation
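// (Standard ARPA back-off: for each history h, the back-off weight is
//    b(h) = (1 - sum_w P(w|h)) / (1 - sum_w P(w|h'))
// where both sums run over the words w explicitly stored for history h, and
// h' is h with its oldest word dropped; computeBackoff presumably fills in
// the weights this way -- see its definition elsewhere in this file.)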
// --- take logs and push estimated values into base CMGramLM structure
// take logs in place
for (int m = 0; m <= M; m++)
for (mgram_map::iterator iter(map, m); iter; ++iter)
P[iter] = logclip(P[iter]); // pruned entries go to logzero
P.swap(logP); // swap into base language model
// --- final housekeeping to account for idiosyncrasies of the ARPA format
// resurrect sentence-start symbol with log score -99
const mgram_map::foundcoord cs = map[mgram_map::key(&startId, 1)];
if (cs.valid_w())
logP[cs] = -99.0f * log(10.0f);
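// (-99 is the conventional ARPA-file stand-in for an "impossible" log10
// probability; it is converted to the natural-log domain used internally,
// hence the factor log(10).)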
// update zerogram prob
// The zerogram will only be used in the OOV case--the non-OOV case has
// been accounted for above by interpolating with the unigram. Thus, we
// replace the zerogram by a value appropriate for an OOV word. We
// choose the minimum unigram prob. This value is not stored in the ARPA
// file, but instead recomputed when loading it. We also reset the
// corresponding back-off weight to 1.0 such that we actually get the
// desired OOV score.
updateOOVScore();
fprintf(stderr, "estimate: done");
for (int m = 1; m <= M; m++)
fprintf(stderr, ", %d %d-grams", logP.size(m), m);
fprintf(stderr, "\n");
}