void estimate()

in Source/Readers/Kaldi2Reader/msra_mgram.h [2379:2807]


    void estimate(int startId, const std::vector<unsigned int> &minObs, vector<bool> dropWord)
    {
        if (!adaptBuffer.empty())
            throw runtime_error("estimate: adaptation buffer not empty, call adapt(*,0) to flush buffer first");

        // Establish w->id mapping -- mapping is identical (w=id) during estimation.
        std::vector<int> w2id(map.maxid() + 1);
        foreach_index (i, w2id)
            w2id[i] = i;
        // std::vector<int> w2id (map.identical_map());

        // close down creation of new tokens, so we can random-access
        map.created(w2id);

        // ensure M reflects the actual order of read data
        while (M > 0 && counts.size(M) == 0)
            resize(M - 1);

        for (int m = 1; m <= M; m++)
            fprintf(stderr, "estimate: read %d %d-grams\n", counts.size(m), m);

// === Kneser-Ney smoothing
// This is a strange algorithm.

#if 1 // Kneser-Ney back-off
        // It seems not to work for fourgram models (applied to the trigram).
        // But if it is only applied to bigram and unigram, there is a gain
        // from the fourgram. So we are not applying it to trigram and above.
        // ... TODO: use a constant to define the maximum KN count level,
        // and then do not allocate memory above that.
        mgram_data<unsigned int> KNCounts;      // [shifted m-gram] (*,v,w)
        mgram_data<unsigned int> KNTotalCounts; // [shifted, shortened m-gram] (*,v,*)
        if (M >= 2)
        {
            fprintf(stderr, "estimate: allocating Kneser-Ney counts...\n");

            KNCounts.init(M - 1);
            for (int m = 0; m <= M - 1; m++)
                KNCounts.assign(m, counts.size(m), 0);
            KNTotalCounts.init(M - 2);
            for (int m = 0; m <= M - 2; m++)
                KNTotalCounts.assign(m, counts.size(m), 0);

            fprintf(stderr, "estimate: computing Kneser-Ney counts...\n");

            // loop over all m-grams to determine KN counts
            for (mgram_map::deep_iterator iter(map); iter; ++iter)
            {
                const mgram_map::key key = *iter;
                if (key.order() < 2)
                    continue; // undefined for unigrams
                const mgram_map::key key_w = key.pop_h();
                const mgram_map::foundcoord c_w = map[key_w];
                if (!c_w.valid_w())
                    throw runtime_error("estimate: invalid shortened KN m-gram");
                KNCounts[c_w]++; // (u,v,w) -> count (*,v,w)
                const mgram_map::key key_h = key_w.pop_w();
                mgram_map::foundcoord c_h = map[key_h];
                if (!c_h.valid_w())
                    throw runtime_error("estimate: invalid shortened KN history");
                KNTotalCounts[c_h]++; // (u,v,w) -> count (*,v,w)
            }
        }
#else // regular back-off: just use regular counts instad
        mgram_data<unsigned int> &KNCounts = counts;
        mgram_data<unsigned int> &KNTotalCounts = counts;
// not 'const' so we can later clear() them... this is only for testng anyway
#endif

        // === estimate "modified Kneser-Ney" discounting values
        // after Chen and Goodman: An empirical study of smoothing techniques for
        // language modeling, CUED TR-09-09 -- a rich resource about everything LM!

        std::vector<double> d1(M + 1, 0.0);
        std::vector<double> d2(M + 1, 0.0);
        std::vector<double> d3(M + 1, 0.0);
        fprintf(stderr, "estimate: discounting values:");

        {
            // actually estimate discounting values
            std::vector<int> n1(M + 1, 0); // how many have count=1, 2, 3, 4
            std::vector<int> n2(M + 1, 0);
            std::vector<int> n3(M + 1, 0);
            std::vector<int> n4(M + 1, 0);

            for (mgram_map::deep_iterator iter(map); iter; ++iter)
            {
                int m = iter.order();
                if (m == 0)
                    continue; // skip the zerogram

                unsigned int count = counts[iter];

                // Kneser-Ney smoothing can also be done for back-off weight computation
                if (m < M && m < 3)                   // for comments see where we estimate the discounted probabilities
                {                                     //    ^^ seems not to work for 4-grams...
                    const mgram_map::key key = *iter; // needed to check for startId
                    assert(key.order() == m);

                    if (m < 2 || key.pop_w().back() != startId)
                    {
                        count = KNCounts[iter];
                        if (count == 0) // must exist
                            throw runtime_error("estimate: malformed data: back-off value not found (numerator)");
                    }
                }

                if (count == 1)
                    n1[m]++;
                else if (count == 2)
                    n2[m]++;
                else if (count == 3)
                    n3[m]++;
                else if (count == 4)
                    n4[m]++;
            }

            for (int m = 1; m <= M; m++)
            {
                if (n1[m] == 0)
                    throw runtime_error(msra::strfun::strprintf("estimate: error estimating discounting values: n1[%d] == 0", m));
                if (n2[m] == 0)
                    throw runtime_error(msra::strfun::strprintf("estimate: error estimating discounting values: n2[%d] == 0", m));
                // if (n3[m] == 0) RuntimeError ("estimate: error estimating discounting values: n3[%d] == 0", m);
                double Y = n1[m] / (n1[m] + 2.0 * n2[m]);
                if (n3[m] == 0 || n4[m] == 0)
                {
                    fprintf(stderr, "estimate: n3[%d] or n4[%d] is 0, falling back to unmodified discounting\n", m, m);
                    d1[m] = Y;
                    d2[m] = Y;
                    d3[m] = Y;
                }
                else
                {
                    d1[m] = 1.0 - 2.0 * Y * n2[m] / n1[m];
                    d2[m] = 2.0 - 3.0 * Y * n3[m] / n2[m];
                    d3[m] = 3.0 - 4.0 * Y * n4[m] / n3[m];
                }
                // ... can these be negative??
                fprintf(stderr, " (%.3f, %.3f, %.3f)", d1[m], d2[m], d3[m]);
            }
            fprintf(stderr, "\n");
        }

        // === threshold against minimum counts (set counts to 0)
        // this is done to save memory, but it has no impact on the seen probabilities
        // ...well, it does, as pruned mass get pushed to back-off distribution... ugh!

        fprintf(stderr, "estimate: pruning against minimum counts...\n");

        // prune unigrams first (unigram cut-off can be higher than m-gram cut-offs,
        // as a means to decimate the vocabulary)

        unsigned int minUniObs = minObs[0]; // minimum unigram count
        int removedWords = 0;
        for (mgram_map::iterator iter(map, 1); iter; ++iter)
        { // unigram pruning is special: may be higher than m-gram threshold
            if (counts[iter] >= minUniObs)
                continue;
            int wid = *iter;
            dropWord[wid] = true; // will throw out all related m-grams
            removedWords++;
        }
        fprintf(stderr, "estimate: removing %d too rare vocabulary entries\n", removedWords);

        // now prune m-grams against count cut-off

        std::vector<int> numMGrams(M + 1, 0);
        msra::basetypes::fixed_vector<mgram_map::coord> histCoord(M); // index of history mgram
        for (int m = 1; m <= M; m++)
        {
            for (mgram_map::deep_iterator iter(map); iter; ++iter)
            {
                if (iter.order() != m)
                    continue;
                bool prune = counts[iter] < minObs[m - 1]; // prune if count below minimum
                // prune by vocabulary
                const mgram_map::key key = *iter;
                for (int k = 0; !prune && k < m; k++)
                {
                    int wid = key[k];
                    prune |= dropWord[wid];
                }
                if (prune)
                {
                    counts[iter] = 0; // pruned: this is how we remember it
                    continue;
                }
                // for remaining words, check whether the structure is still intact
                if (m < M)
                    histCoord[m] = iter;
                mgram_map::coord j = histCoord[m - 1]; // parent
                if (counts[j] == 0)
                    RuntimeError("estimate: invalid pruning: a parent m-gram got pruned away");
                // throw runtime_error ("estimate: invalid pruning: a parent m-gram got pruned away");
                numMGrams[m]++;
            }
        }

        for (int m = 1; m <= M; m++)
        {
            fprintf(stderr, "estimate: %d-grams after pruning: %d out of %d (%.1f%%)\n", m,
                    numMGrams[m], counts.size(m),
                    100.0 * numMGrams[m] / max(counts.size(m), 1));
        }

        // ensure M reflects the actual order of read data after pruning
        while (M > 0 && numMGrams[M] == 0)
            resize(M - 1); // will change M

        // === compact memory after pruning

        // naw... this is VERY tricky with the mgram_map architecture to keep all data in sync
        // So for now we just skip those in all subsequent steps (i.e. we don't save memory)

        // === estimate M-gram

        fprintf(stderr, "estimate: estimating probabilities...\n");

        // dimension the m-gram store
        mgram_data<float> P(M); // [M+1][i] probabilities
        for (int m = 1; m <= M; m++)
            P.reserve(m, numMGrams[m]);

        // compute discounted probabilities (uninterpolated except, later, for unigram)

        // We estimate into a new map so that pruned items get removed.
        // For large data sets, where strong pruning is used, there is a net
        // memory gain from doing this (we gain if pruning cuts more than half).
        mgram_map Pmap(M);
        for (int m = 1; m <= M; m++)
            Pmap.reserve(m, numMGrams[m]);
        mgram_map::cache_t PmapCache; // used in map.create()

        // m-grams
        P.push_back(mgram_map::coord(), 0.0f); // will be updated later
        for (int m = 1; m <= M; m++)
        {
            fprintf(stderr, "estimate: estimating %d %d-gram probabilities...\n", numMGrams[m], m);

            // loop over all m-grams of level 'm'
            msra::basetypes::fixed_vector<mgram_map::coord> histCoord(m);
            for (mgram_map::deep_iterator iter(map, m); iter; ++iter)
            {
                if (iter.order() != m)
                {
                    // a parent: remember how successors can find their history
                    // (files are nested like a tree)
                    histCoord[iter.order()] = iter;
                    continue;
                }

                const mgram_map::key key = *iter;
                assert(key.order() == iter.order()); // (remove this check once verified)

                // get history's count
                const mgram_map::coord j = histCoord[m - 1]; // index of parent entry
                double histCount = counts[j];                // parent count --before pruning
                // double histCount = succCount[j];        // parent count --actuals after pruning

                // estimate probability for this M-gram
                unsigned int count = counts[iter];
                // this is 0 for pruned entries

                // count = numerator, histCount = denominator

                // Kneser-Ney smoothing --replace all but the highest-order
                // distribution with that strange Kneser-Ney smoothed distribution.
                if (m < M && m < 3 && count > 0) // all non-pruned items except highest order
                {                                //    ^^ seems not to work for 4-gram
                    // We use a normal distribution if the history is the sentence
                    // start, as there we fallback without back-off. [Thanks to
                    // Yining Chen for the tip.]
                    if (m < 2 || key.pop_w().back() != startId)
                    {
                        count = KNCounts[iter]; // (u,v,w) -> count (*,v,w)
                        if (count == 0)         // must exist
                            RuntimeError("estimate: malformed data: back-off value not found (numerator)");

                        const mgram_map::key key_h = key.pop_w();
                        mgram_map::foundcoord c_h = map[key_h];
                        if (!c_h.valid_w())
                            throw runtime_error("estimate: invalid shortened KN history");
                        histCount = KNTotalCounts[c_h]; // (u,v,w) -> count (*,v,*)
                        if (histCount == 0)             // must exist
                            RuntimeError("estimate: malformed data: back-off value not found (denominator)");
                        assert(histCount >= count);
                    }
                }

                // pruned case
                if (count == 0) // this entry was pruned before
                    goto skippruned;

                // <s> does not count as an event, as it is never emitted.
                // For now we prune it, but later we put the unigram back with -99.0.
                if (key.back() == startId)
                {              // (u, v, <s>)
                    if (m > 1) // do not generate m-grams
                        goto skippruned;
                    count = 0; // unigram is kept in structure
                }
                else if (m == 1)
                {                // unigram non-<s> events
                    histCount--; // do not count <s> in denominator either
                    // For non-unigrams, we don't need to care because m-gram
                    // histories of <s> always ends in </s>, and we never ask for such an m-gram
                    // ... TODO: actually, is subtracting 1 the right thing to do here?
                    // shouldn't we subtract the unigram count of <s> instead?
                }

                // Histories with any token before <s> are not valuable, and
                // actually need to be removed for consistency with the above
                // rule of removing m-grams predicting <s> (if we don't we may
                // create orphan m-grams).
                for (int k = 1; k < m - 1; k++)
                { // ^^ <s> at k=0 and k=m-1 is OK; anywhere else -> useless m-gram
                    if (key[k] == startId)
                        goto skippruned;
                }

                // estimate discounted probability
                double dcount = count; // "modified Kneser-Ney" discounting
                if (count >= 3)
                    dcount -= d3[m];
                else if (count == 2)
                    dcount -= d2[m];
                else if (count == 1)
                    dcount -= d1[m];
                if (dcount < 0.0) // 0.0 itself is caused by <s>
                    throw runtime_error("estimate: negative discounted count value");

                if (histCount == 0)
                    RuntimeError("estimate: unexpected 0 denominator");
                double dP = dcount / histCount;
                // and this is the discounted probability value
                {
                    // Actually, 'key' uses a "mapped" word ids, while create()
                    // expects unmapped ones. However, we have established an
                    // identical mapping at the start of this function, such that
                    // we can be sure that key=unmapped key.
                    mgram_map::coord c = Pmap.create((mgram_map::unmapped_key) key, PmapCache);
                    P.push_back(c, (float) dP);
                }

            skippruned:; // m-gram was pruned
            }
        }
        // the distributions are not normalized --discount mass is missing
        fprintf(stderr, "estimate: freeing memory for counts...\n");
        KNCounts.clear(); // free some memory
        KNTotalCounts.clear();

        // the only items used below are P and Pmap.
        w2id.resize(Pmap.maxid() + 1);
        foreach_index (i, w2id)
            w2id[i] = i;
        // std::vector<int> w2id (Pmap.identical_map());
        Pmap.created(w2id); // finalize and establish mapping for read access
        map.swap(Pmap);     // install the new map in our m-gram
        Pmap.clear();       // no longer using the old one

        counts.clear(); // counts also no longer needed

        // zerogram
        int vocabSize = 0;
        for (mgram_map::iterator iter(map, 1); iter; ++iter)
            if (P[iter] > 0.0) // (note: this excludes <s> and all pruned items)
                vocabSize++;
        P[mgram_map::coord()] = (float) (1.0 / vocabSize); // zerogram probability

        // interpolating the unigram with the zerogram
        // This is necessary as there is no back-off path from the unigram
        // except in the OOV case. I.e. probability mass that was discounted
        // from the unigrams is lost. We fix it by using linear interpolation
        // instead of strict discounting for the unigram distribution.
        double unigramSum = 0.0;
        for (mgram_map::iterator iter(map, 1); iter; ++iter)
            unigramSum += P[iter];
        double missingUnigramMass = 1.0 - unigramSum;
        if (missingUnigramMass > 0.0)
        {
            float missingUnigramProb = (float) (missingUnigramMass * P[mgram_map::coord()]);
            fprintf(stderr, "estimate: distributing missing unigram mass of %.2f to %d unigrams\n",
                    missingUnigramMass, vocabSize);
            for (mgram_map::iterator iter(map, 1); iter; ++iter)
            {
                if (P[iter] == 0.0f)
                    continue;                  // pruned
                P[iter] += missingUnigramProb; // add it in
            }
        }

        // --- M-gram sections --back-off weights

        fprintf(stderr, "estimate: determining back-off weights...\n");
        computeBackoff(map, M, P, logB, false);
        // now the LM is normalized assuming the ARPA back-off computation

        // --- take logs and push estimated values into base CMGramLM structure

        // take logs in place
        for (int m = 0; m <= M; m++)
            for (mgram_map::iterator iter(map, m); iter; ++iter)
                P[iter] = logclip(P[iter]); // pruned entries go to logzero
        P.swap(logP);                       // swap into base language model

        // --- final housekeeping to account for idiosyncrasies of the ARPA format

        // resurrect sentence-start symbol with log score -99
        const mgram_map::foundcoord cs = map[mgram_map::key(&startId, 1)];
        if (cs.valid_w())
            logP[cs] = -99.0f * log(10.0f);

        // update zerogram prob
        // The zerogram will only be used in the OOV case--the non-OOV case has
        // been accounted for above by interpolating with the unigram. Thus, we
        // replace the zerogram by a value appropriate for an OOV word. We
        // choose the minimum unigram prob. This value is not stored in the ARPA
        // file, but instead recomputed when loading it. We also reset the
        // corresponding back-off weight to 1.0 such that we actually get the
        // desired OOV score.
        updateOOVScore();

        fprintf(stderr, "estimate: done");
        for (int m = 1; m <= M; m++)
            fprintf(stderr, ", %d %d-grams", logP.size(m), m);
        fprintf(stderr, "\n");
    }