void DoWriteWordAndClassInfo()

in Source/ActionsLib/OtherActions.cpp [245:499]


// DoWriteWordAndClassInfo() - build a vocabulary, and optionally frequency-based word classes, from a text corpus.
//
// Reads 'inputFile' line by line, counts the tokens of each line, and writes
//  - outputVocabFile:   one line per word: index, count, word, class index (0 when classes are disabled)
//  - outputMappingFile: (optional) one word per line, in the same order as the vocabulary file
//  - outputWord2Cls:    (nbrClass > 0 only) class index of every word, one per line
//  - outputCls2Index:   (nbrClass > 0 only) index of the first word of every class, one per line
//
// Config parameters read here:
//   vocabSize, nbrClass (default 0), cutoff (default 1), inputFile, outputMappingFile (default empty),
//   outputVocabFile, outputWord2Cls / outputCls2Index (only read when nbrClass > 0),
//   unk (default "<unk>"), beginSequence, endSequence (both must be non-empty), makeMode (default true).
//
// Errors are reported through InvalidArgument() (bad/missing parameters) and RuntimeError()
// (input cannot be opened, no word survives the cutoff, an output file cannot be created).
//
// Notes:
//  - 'cutoff' and 'vocabSize' only limit the vocabulary in the class-based branch; without classes
//    every distinct token of the corpus is written.
//  - NOTE(review): counts are kept as double and streamed with default formatting, so very large
//    counts may be written in scientific notation with reduced precision -- confirm readers cope.
void DoWriteWordAndClassInfo(const ConfigParameters& config)
{
    size_t vocabSize = config(L"vocabSize");
    int nbrCls = config(L"nbrClass", "0"); // TODO: why int and not size_t?
    int cutoff = config(L"cutoff", "1");

    string inputFile = config(L"inputFile"); // training text file without <unk>
    string outputMappingFile = config(L"outputMappingFile", ""); // if specified then write a regular mapping file
    string outputVocabFile   = config(L"outputVocabFile");
    string outputWord2Cls  = nbrCls > 0 ? config(L"outputWord2Cls") : string();
    string outputCls2Index = nbrCls > 0 ? config(L"outputCls2Index") : string();

    string unkWord       = config(L"unk", "<unk>");
    string beginSequence = config(L"beginSequence", "");
    string endSequence   = config(L"endSequence",   "");
    // legacy: Old version hard-coded "</s>" for ^^ both of these.
    //         For a while, do not fall back to defaults but rather have users fix their scripts.
    if (beginSequence.empty() || endSequence.empty())
        InvalidArgument("Please specify parameters 'beginSequence' and 'endSequence'.");

    if (!outputMappingFile.empty())
        cerr << "Mapping file       --> " << outputMappingFile << endl;
    cerr     << "Vocabulary file    --> " << outputVocabFile   << endl;
    if (nbrCls > 0)
    {
        cerr << "Word-to-class map  --> " << outputWord2Cls  << endl;
        cerr << "Class-to-index map --> " << outputCls2Index << endl;
    }
    cerr << endl;

    // check whether we are already up-to-date
    bool makeMode = config(L"makeMode", true);
    if (makeMode)
    {
        bool done = msra::files::fuptodate(Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(outputVocabFile), Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(inputFile), /*inputRequired=*/false);
        // the optional mapping file is an output as well; without this check a requested but
        // missing/stale mapping file would never be (re-)created once the vocabulary file is current
        if (!outputMappingFile.empty())
            done &= msra::files::fuptodate(Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(outputMappingFile), Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(inputFile), /*inputRequired=*/false);
        if (nbrCls > 0)
        {
            done &= msra::files::fuptodate(Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(outputWord2Cls), Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(inputFile), /*inputRequired=*/false);
            done &= msra::files::fuptodate(Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(outputCls2Index), Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(inputFile), /*inputRequired=*/false);
        }
        if (done)
        {
            cerr << "All output files up to date.\n";
            return;
        }
    }

    Matrix<ElemType> wrd2cls(CPUDEVICE); // [word index] -> class index
    Matrix<ElemType> cls2idx(CPUDEVICE); // [class index] -> index of the first word of that class

    ifstream fp(inputFile.c_str()); // TODO: use class File, as to support pipes
    if (!fp)
        RuntimeError("Failed to open input file: %s", inputFile.c_str());
    cerr << "Reading input file inputFile: " << inputFile << endl;

    if (nbrCls > 0)
        cls2idx.Resize(nbrCls, 1);

    unordered_map<string, double> v_count; // token -> number of occurrences in the corpus

    // process input line by line
    string str;
    vector<string> vstr;
    const string beginSequencePattern = beginSequence + " ";
    const string endSequencePattern   = " " + endSequence;
    // This loop used to start with 1, assuming begin and end symbol are the same.
    // If they are not, I am now counting them both. No idea whether that is correct w.r.t. the class algorithm.
    const bool startWith1 = !beginSequence.empty() && beginSequence == endSequence;
    while (getline(fp, str))
    {
        str.erase(0, str.find_first_not_of(' ')); // leading spaces
        str.erase(str.find_last_not_of(' ') + 1); // trailing spaces

        // add the sequence delimiters unless the line already contains them
        // (note: the patterns are searched anywhere in the line, not only at its ends)
        if (!beginSequence.empty() && str.find(beginSequencePattern) == str.npos)
            str = beginSequencePattern + str;

        if (!endSequence.empty() && str.find(endSequencePattern) == str.npos)
            str = str + endSequencePattern;

        vstr = msra::strfun::split(str, "\t ");
        // when begin and end symbol are identical, skip the first token so that the
        // shared delimiter is counted only once per line
        for (size_t i = startWith1 ? 1 : 0; i < vstr.size(); i++)
            v_count[vstr[i]]++;
    }
    fp.close();

    cerr << "Vocabulary size " << v_count.size() << ".\n";

    vector<string> m_words;                // vocabulary in output order
    unordered_map<string, size_t> m_index; // word -> position in m_words

    vector<double> m_count; // count of each word, parallel to m_words
    vector<int> m_class;    // class index of each word, parallel to m_words (class mode only)

    // number of distinct words whose count exceeds the cutoff
    size_t wordCountLessCutoff = v_count.size();
    if (cutoff > 0)
        for (const auto& iter : v_count)
        {
            if (iter.second <= cutoff)
                wordCountLessCutoff--;
        }
    if (wordCountLessCutoff == 0)
        RuntimeError("No word remained after cutoff with threshold %d.", (int)cutoff);

    if (vocabSize > wordCountLessCutoff)
    {
        cerr << "Warning: actual vocabulary size is less than required." << endl;
        cerr << "\t\tRequired vocabulary size:" << vocabSize << endl;
        cerr << "\t\tActual vocabulary size:" << v_count.size() << endl;
        cerr << "\t\tActual vocabulary size after cutoff:" << wordCountLessCutoff << endl;
        cerr << "\t\tWe will change to actual vocabulary size: " << wordCountLessCutoff << endl;
        vocabSize = wordCountLessCutoff;
    }

    if (nbrCls > 0)
    {
        // 'vocabSize - 1' below would wrap around for 0, selecting every word while
        // wrd2cls has no rows; reject that configuration instead of writing out of range
        if (vocabSize == 0)
            InvalidArgument("Parameter 'vocabSize' must be greater than 0 when 'nbrClass' is specified.");

        // form classes
        // Implements an algorithm by Mikolov --TODO: get the reference
        wrd2cls.Resize(vocabSize, 1);

        typedef pair<string, double> stringdouble;
        unordered_map<string, double> removed; // words kept in the vocabulary; note: std::map is supposedly faster
        double unkCount = 0; // TODO: why double?
        size_t size = 0;
        size_t actual_vocab_size = vocabSize - 1; // one slot is reserved for unkWord
        // queue ordered by count through compare_second (presumably highest count on top -- defined elsewhere)
        priority_queue<stringdouble, vector<stringdouble>, compare_second<stringdouble>>
            q(compare_second<stringdouble>(), vector<stringdouble>(v_count.begin(), v_count.end()));
        while (size < actual_vocab_size && !q.empty()) // ==for (q=...; cond; q.pop())
        {
            size++;
            const string word = q.top().first;
            const double freq = q.top().second; // TODO: why double?
            if (word == unkWord)
            {
                // unkWord has its reserved slot already, so it must not use up a regular one
                unkCount += freq;
                actual_vocab_size++;
            }
            removed[word] = freq;
            q.pop();
        }
        // everything that did not make it into the vocabulary is folded into unkWord
        while (!q.empty())
        {
            unkCount += q.top().second;
            q.pop();
        }
        removed[unkWord] = unkCount;
        m_count.resize(removed.size());

        // total:  sum of all counts
        // dd:     sum over all words of sqrt(relative frequency), used to normalize df below
        double total = 0;
        double dd = 0;
        for (const auto& iter : removed)
            total += iter.second;

        for (const auto& iter : removed)
            dd += sqrt(iter.second / total);

        double df = 0;       // running normalized sum of sqrt(relative frequency), clamped to 1
        size_t class_id = 0; // class of the word currently being processed
        m_class.resize(removed.size());

        priority_queue<stringdouble, vector<stringdouble>, compare_second<stringdouble>>
            p(compare_second<stringdouble>(), vector<stringdouble>(removed.begin(), removed.end()));
        while (!p.empty())
        {
            string word = p.top().first;
            double freq = p.top().second;

            df += sqrt(freq / total) / dd;
            if (df > 1)
                df = 1;

            // advance to the next class once the accumulated mass exceeds this class' share
            if (df > 1.0 * (class_id + 1) / nbrCls && class_id < (size_t) nbrCls)
                class_id++;

            size_t wid = m_words.size();
            bool inserted = m_index.insert(make_pair(word, wid)).second;
            if (inserted)
                m_words.push_back(word);

            m_count[wid] = freq;
            m_class[wid] = (int) class_id;
            p.pop();
        }
        assert(m_words.size() == m_index.size() && m_words.size() == m_class.size());
    }
    else // no classes
    {
        // all words, sorted alphabetically; neither cutoff nor vocabSize is applied here
        for (let& iter : v_count)
            m_words.push_back(iter.first);
        sort(m_words.begin(), m_words.end());
        m_count.resize(m_words.size());
        for (size_t i = 0; i < m_words.size(); i++)
            m_count[i] = v_count.find(m_words[i])->second;
    }

    assert(m_words.size() == m_count.size());

    // write the files
    if (!outputMappingFile.empty())
    {
        msra::files::make_intermediate_dirs(Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(outputMappingFile));
        ofstream ofmapping(outputMappingFile.c_str());
        if (!ofmapping)
            RuntimeError("Failed to write to %s", outputMappingFile.c_str());
        for (let& word : m_words)
            ofmapping << word << '\n';
        ofmapping.close();
        // report what was actually written, which differs from v_count.size() when the vocabulary got truncated
        cerr << "Created label-mapping file with " << m_words.size() << " entries.\n";
    }

    msra::files::make_intermediate_dirs(Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(outputVocabFile));
    ofstream ofvocab(outputVocabFile.c_str());
    if (!ofvocab)
        RuntimeError("Failed to write to %s", outputVocabFile.c_str());
    long long prevClsIdx = -1; // class of the previously written word; -1 = none yet
    for (size_t i = 0; i < m_words.size(); i++)
    {
        long long clsIdx = nbrCls > 0 ? m_class[i] : 0;
        if (nbrCls > 0)
        {
            wrd2cls(i, 0) = (ElemType) m_class[i];
            if (clsIdx != prevClsIdx)
            {
                cls2idx(clsIdx, 0) = (ElemType) i; // the left boundary of clsIdx
                prevClsIdx = clsIdx;
            }
        }
        ofvocab << "     " << i << "\t     " << m_count[i] << "\t" << m_words[i] << "\t" << clsIdx << '\n';
    }
    ofvocab.close();
    cerr << "Created vocabulary file with " << m_words.size() << " entries.\n";

    if (nbrCls > 0)
    {
        // write the outputs
        // TODO: use safe-save, i.e. write to temp name and rename at the end
        msra::files::make_intermediate_dirs(Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(outputWord2Cls));
        ofstream owfp(outputWord2Cls.c_str());
        if (!owfp)
            RuntimeError("Failed to write to %s", outputWord2Cls.c_str());
        for (size_t r = 0; r < wrd2cls.GetNumRows(); r++)
            owfp << (int) wrd2cls(r, 0) << '\n';
        owfp.close();
        cerr << "Created word-to-class map with " << wrd2cls.GetNumRows() << " entries.\n";

        msra::files::make_intermediate_dirs(Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(outputCls2Index));
        ofstream ocfp(outputCls2Index.c_str());
        if (!ocfp)
            RuntimeError("Failed to write to %s", outputCls2Index.c_str());
        for (size_t r = 0; r < cls2idx.GetNumRows(); r++)
            ocfp << (int) cls2idx(r, 0) << '\n';
        ocfp.close();
        cerr << "Created class-to-index map with " << cls2idx.GetNumRows() << " entries.\n";
    }
}