void HTKMLFReader::PrepareForTrainingOrTesting()

in Source/Readers/HTKMLFReader/HTKMLFReader.cpp [108:603]


// Parses the reader configuration and builds the frame source used for training/testing.
//
// Steps, in order:
//  1. Each feature section: dimension (multiplied by the context-window size), expandToUtterance
//     flag, context window, SCP file and optional prefixPathInSCP.
//  2. Each label section: label dimension/type, labelMappingFile (state list), MLF file(s) and an
//     optional labelToTargetMappingFile.
//  3. Lattice TOC files and HMM files (only the first HMM set is loaded into m_hset).
//  4. Global options (randomize, frameMode, verbosity, minibatchMode, readMethod), validated here.
//  5. All SCP files (with prefix or "..." expansion), an optional unigram LM, and all MLFs.
//  6. m_frameSource: minibatchutterancesourcemulti for readMethod=blockRandomize, or
//     minibatchframesourcemulti backed by temporary page files for readMethod=rollingWindow.
//
// Invalid or inconsistent configuration is reported through InvalidArgument() / RuntimeError().
void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType& readerConfig)
{
    vector<wstring> scriptpaths;
    vector<wstring> RootPathInScripts;
    wstring RootPathInLatticeTocs;
    vector<wstring> mlfpaths;
    vector<vector<wstring>> mlfpathsmulti;
    size_t firstfilesonly = SIZE_MAX; // maximum number of SCP entries read per script; set lower for testing
    vector<vector<wstring>> infilesmulti;
    size_t numFiles = 0;
    wstring unigrampath(L"");

    size_t randomize = randomizeAuto;
    size_t iFeat = 0;
    size_t iLabel = 0;
    vector<wstring> statelistpaths;
    vector<size_t> numContextLeft;
    vector<size_t> numContextRight;
    size_t numExpandToUtt = 0;

    std::vector<std::wstring> featureNames;
    std::vector<std::wstring> labelNames;

    // for hmm and lattice
    std::vector<std::wstring> hmmNames;
    std::vector<std::wstring> latticeNames;
    GetDataNamesFromConfig(readerConfig, featureNames, labelNames, hmmNames, latticeNames);
    if (featureNames.size() + labelNames.size() <= 1)
    {
        InvalidArgument("network needs at least 1 input and 1 output specified!");
    }

    // load data for all real-valued inputs (features)
    foreach_index (i, featureNames)
    {
        const ConfigRecordType& thisFeature = readerConfig(featureNames[i]);
        m_featDims.push_back(thisFeature(L"dim"));

        bool expandToUtt = thisFeature(L"expandToUtterance", false); // should feature be processed as an ivector?
        m_expandToUtt.push_back(expandToUtt);
        if (expandToUtt)
            numExpandToUtt++;

        intargvector contextWindow = thisFeature(L"contextWindow", ConfigRecordType::Array(intargvector(vector<int>{1})));
        if (contextWindow.size() == 1) // symmetric
        {
            size_t windowFrames = contextWindow[0];
            if (windowFrames % 2 == 0)
                InvalidArgument("augmentationextent: neighbor expansion of input features to %d not symmetrical", (int) windowFrames);

            size_t context = windowFrames / 2; // extend each side by this
            numContextLeft.push_back(context);
            numContextRight.push_back(context);
        }
        else if (contextWindow.size() == 2) // left context, right context
        {
            numContextLeft.push_back(contextWindow[0]);
            numContextRight.push_back(contextWindow[1]);
        }
        else
        {
            InvalidArgument("contextWindow must have 1 or 2 values specified, found %d", (int) contextWindow.size());
        }

        if (expandToUtt && (numContextLeft[i] != 0 || numContextRight[i] != 0))
            RuntimeError("contextWindow expansion not permitted when expandToUtterance=true");

        // update m_featDims to reflect the total input dimension (featDim x contextWindow), not the native feature dimension
        // that is what the lower level feature readers expect
        m_featDims[i] = m_featDims[i] * (1 + numContextLeft[i] + numContextRight[i]);

        wstring type = thisFeature(L"type", L"real");
        if (EqualCI(type, L"real"))
        {
            m_nameToTypeMap[featureNames[i]] = InputOutputTypes::real;
        }
        else
        {
            InvalidArgument("feature type must be 'real'");
        }

        m_featureNameToIdMap[featureNames[i]] = iFeat;
        wstring type2 = thisFeature(L"scpFile");
        scriptpaths.push_back(type2);
        RootPathInScripts.push_back(thisFeature(L"prefixPathInSCP", L""));
        m_featureNameToDimMap[featureNames[i]] = m_featDims[i];

        m_featuresBufferMultiIO.push_back(nullptr);
        m_featuresBufferAllocatedMultiIO.push_back(0);

        iFeat++;
    }

    foreach_index (i, labelNames)
    {
        const ConfigRecordType& thisLabel = readerConfig(labelNames[i]);
        if (thisLabel.Exists(L"labelDim"))
            m_labelDims.push_back(thisLabel(L"labelDim"));
        else if (thisLabel.Exists(L"dim"))
            m_labelDims.push_back(thisLabel(L"dim"));
        else
            InvalidArgument("labels must specify dim or labelDim");

        wstring type;
        if (thisLabel.Exists(L"labelType"))
            type = (const wstring&) thisLabel(L"labelType"); // let's deprecate this eventually and just use "type"...
        else
            type = (const wstring&) thisLabel(L"type", L"category"); // outputs should default to category

        if (EqualCI(type, L"category"))
            m_nameToTypeMap[labelNames[i]] = InputOutputTypes::category;
        else
            InvalidArgument("label type must be 'category'");

        statelistpaths.push_back(thisLabel(L"labelMappingFile", L""));

        m_labelNameToIdMap[labelNames[i]] = iLabel;
        m_labelNameToDimMap[labelNames[i]] = m_labelDims[i];
        mlfpaths.clear();
        if (thisLabel.ExistsCurrent(L"mlfFile"))
        {
            wstring type2 = thisLabel(L"mlfFile");
            mlfpaths.push_back(type2);
        }
        else
        {
            if (!thisLabel.ExistsCurrent(L"mlfFileList"))
            {
                InvalidArgument("Either mlfFile or mlfFileList must exist in HTKMLFReader");
            }

            // mlfFileList names a text file that lists one MLF path per line
            wstring list = thisLabel(L"mlfFileList");
            for (msra::files::textreader r(list); r;)
            {
                mlfpaths.push_back(r.wgetline());
            }
        }
        mlfpathsmulti.push_back(mlfpaths);

        m_labelsBufferMultiIO.push_back(nullptr);
        m_labelsBufferAllocatedMultiIO.push_back(0);

        iLabel++;

        wstring labelToTargetMappingFile(thisLabel(L"labelToTargetMappingFile", L""));
        if (labelToTargetMappingFile != L"")
        {
            std::vector<std::vector<ElemType>> labelToTargetMap;
            m_convertLabelsToTargetsMultiIO.push_back(true);
            if (thisLabel.Exists(L"targetDim"))
            {
                m_labelNameToDimMap[labelNames[i]] = m_labelDims[i] = thisLabel(L"targetDim");
            }
            else
            {
                RuntimeError("output must specify targetDim if labelToTargetMappingFile specified!");
            }

            size_t targetDim = ReadLabelToTargetMappingFile(labelToTargetMappingFile, statelistpaths[i], labelToTargetMap);
            if (targetDim != m_labelDims[i])
                RuntimeError("mismatch between targetDim and dim found in labelToTargetMappingFile");
            m_labelToTargetMapMultiIO.push_back(labelToTargetMap);
        }
        else
        {
            m_convertLabelsToTargetsMultiIO.push_back(false);
            m_labelToTargetMapMultiIO.push_back(std::vector<std::vector<ElemType>>());
        }
    }

    // get lattice toc file names (first = numerator TOCs, second = denominator TOCs)
    std::pair<std::vector<wstring>, std::vector<wstring>> latticetocs;
    foreach_index (i, latticeNames) // only support one set of lattice now
    {
        const ConfigRecordType& thisLattice = readerConfig(latticeNames[i]);

        vector<wstring> paths;
        expand_wildcards(thisLattice(L"denLatTocFile"), paths);
        latticetocs.second.insert(latticetocs.second.end(), paths.begin(), paths.end());

        if (thisLattice.Exists(L"numLatTocFile"))
        {
            paths.clear();
            expand_wildcards(thisLattice(L"numLatTocFile"), paths);
            latticetocs.first.insert(latticetocs.first.end(), paths.begin(), paths.end());
        }
        RootPathInLatticeTocs = (wstring) thisLattice(L"prefixPathInToc", L"");
    }

    // get HMM related file names
    vector<wstring> cdphonetyingpaths, transPspaths;
    foreach_index (i, hmmNames)
    {
        const ConfigRecordType& thisHMM = readerConfig(hmmNames[i]);

        wstring type2 = thisHMM(L"phoneFile");
        cdphonetyingpaths.push_back(type2);
        transPspaths.push_back(thisHMM(L"transPFile", L""));
    }

    // mmf files
    // only support one set now
    if (cdphonetyingpaths.size() > 0 && statelistpaths.size() > 0 && transPspaths.size() > 0)
        m_hset.loadfromfile(cdphonetyingpaths[0], statelistpaths[0], transPspaths[0]);

    if (iFeat != scriptpaths.size() || iLabel != mlfpathsmulti.size())
        RuntimeError("# of inputs files vs. # of inputs or # of output files vs # of outputs inconsistent");

    // also guarantees at least one feature stream exists, so m_expandToUtt[0] below is valid
    if (iFeat == numExpandToUtt)
        RuntimeError("At least one feature stream must be frame-based, not utterance-based");

    if (m_expandToUtt[0]) // first feature stream is ivector type - that will mess up lower level feature reader
        RuntimeError("The first feature stream in the file must be frame-based not utterance based. Please reorder the feature blocks of your config appropriately");

    if (readerConfig.Exists(L"randomize"))
    {
        wstring randomizeString = readerConfig.CanBeString(L"randomize") ? readerConfig(L"randomize") : wstring();
        if      (EqualCI(randomizeString, L"none")) randomize = randomizeNone;
        else if (EqualCI(randomizeString, L"auto")) randomize = randomizeAuto;
        else                                        randomize = readerConfig(L"randomize"); // TODO: could this not just be randomizeString?
    }

    m_frameMode = readerConfig(L"frameMode", true);
    m_verbosity = readerConfig(L"verbosity", 0);

    if (m_frameMode && m_truncated)
    {
        InvalidArgument("'Truncated' cannot be 'true' in frameMode (i.e. when 'frameMode' is 'true')");
    }

    // determine if we partial minibatches are desired
    wstring minibatchMode(readerConfig(L"minibatchMode", L"partial"));
    m_partialMinibatch = EqualCI(minibatchMode, L"partial");

    // get the read method, defaults to "blockRandomize" other option is "rollingWindow".
    // Compared case-insensitively everywhere so that validation and dispatch always agree.
    wstring readMethod(readerConfig(L"readMethod", L"blockRandomize"));
    const bool useBlockRandomize = EqualCI(readMethod, L"blockRandomize");
    const bool useRollingWindow = EqualCI(readMethod, L"rollingWindow");

    // fail fast, before the potentially expensive loading of SCP and MLF files below
    if (!useBlockRandomize && !useRollingWindow)
        RuntimeError("readMethod must be 'rollingWindow' or 'blockRandomize'");

    if (useBlockRandomize && randomize == randomizeNone)
        InvalidArgument("'randomize' cannot be 'none' when 'readMethod' is 'blockRandomize'.");

    if (useRollingWindow && numExpandToUtt > 0)
        RuntimeError("rollingWindow reader does not support expandToUtt. Change to blockRandomize.");

    // read all input files (from multiple inputs); every script file must list the same number of entries
    foreach_index (i, scriptpaths)
    {
        vector<wstring> filelist;
        std::wstring scriptpath = scriptpaths[i];
        fprintf(stderr, "reading script file %ls ...", scriptpath.c_str());
        size_t n = 0;
        for (msra::files::textreader reader(scriptpath); reader && filelist.size() < firstfilesonly /*optimization*/;)
        {
            filelist.push_back(reader.wgetline());
            n++;
        }

        fprintf(stderr, " %lu entries\n", (unsigned long)n);

        if (i == 0)
            numFiles = n;
        else if (n != numFiles)
            RuntimeError("number of files in each scriptfile inconsistent (%d vs. %d)", (int) numFiles, (int) n);

        // post processing file list :
        //  - if users specified PrefixPath, add the prefix to each of path in filelist
        //  - else do the dotdotdot expansion if necessary
        wstring rootpath = RootPathInScripts[i];
        if (!rootpath.empty()) // user has specified a path prefix for this feature
        {
            // first make slashes consistent (a no-op for paths that already use '/')
            std::replace(rootpath.begin(), rootpath.end(), L'\\', L'/');

            // second, remove trailing slashes, if any. For an all-slash string find_last_not_of
            // returns npos, and npos + 1 wraps to 0, which clears the string.
            rootpath.erase(rootpath.find_last_not_of(L'/') + 1);

            // third, join the rootpath with each entry in filelist
            if (!rootpath.empty())
            {
                for (wstring& path : filelist)
                {
                    if (path.find(L'=') != wstring::npos) // "logical=physical" entry: prefix only the physical part
                    {
                        vector<wstring> strarr = msra::strfun::split(path, L"=");
#ifdef _WIN32
                        replace(strarr[1].begin(), strarr[1].end(), L'\\', L'/');
#endif

                        path = strarr[0] + L"=" + rootpath + L"/" + strarr[1];
                    }
                    else
                    {
#ifdef _WIN32
                        replace(path.begin(), path.end(), L'\\', L'/');
#endif
                        path = rootpath + L"/" + path;
                    }
                }
            }
        }
        else
        {
            /*
            do "..." expansion if SCP uses relative path names
            "..." in the SCP means full path is the same as the SCP file
            for example, if scp file is "//aaa/bbb/ccc/ddd.scp"
            and contains entry like
            .../file1.feat
            .../file2.feat
            etc.
            the features will be read from
            // aaa/bbb/ccc/file1.feat
            // aaa/bbb/ccc/file2.feat
            etc.
            This works well if you store the scp file with the features but
            do not want different scp files everytime you move or create new features
            */
            wstring scpdircached;
            for (auto& entry : filelist)
                ExpandDotDotDot(entry, scriptpath, scpdircached);
        }

        infilesmulti.push_back(std::move(filelist));
    }

    if (readerConfig.Exists(L"unigram"))
        unigrampath = (const wstring&) readerConfig(L"unigram");

    // load a unigram if needed (this is used for MMI training)
    msra::lm::CSymbolSet unigramsymbols;
    std::unique_ptr<msra::lm::CMGramLM> unigram;
    size_t silencewordid = SIZE_MAX;
    size_t startwordid = SIZE_MAX;
    size_t endwordid = SIZE_MAX;
    if (unigrampath != L"")
    {
        unigram.reset(new msra::lm::CMGramLM());
        unigram->read(unigrampath, unigramsymbols, false /*filterVocabulary--false will build the symbol map*/, 1 /*maxM--unigram only*/);
        silencewordid = unigramsymbols["!silence"]; // give this an id (even if not in the LM vocabulary)
        startwordid = unigramsymbols["<s>"];
        endwordid = unigramsymbols["</s>"];
    }

    if (!unigram && latticetocs.second.size() > 0)
        fprintf(stderr, "trainlayer: OOV-exclusion code enabled, but no unigram specified to derive the word set from, so you won't get OOV exclusion\n");

    // currently assumes all mlfs will have same root name (key)
    set<wstring> restrictmlftokeys; // restrict MLF reader to these files--will make stuff much faster without having to use shortened input files
    if (infilesmulti[0].size() <= 100)
    {
        foreach_index (i, infilesmulti[0])
        {
            msra::asr::htkfeatreader::parsedpath ppath(infilesmulti[0][i]);
            const wstring ppathStr = (wstring) ppath;

            // Delete the extension, if any. A '.' only counts as an extension separator when it comes
            // after the last path separator / drive colon. Entries with none of these characters
            // (e.g. a bare key) are used as-is.
            const size_t lastSpecial = ppathStr.find_last_of(L".\\/:");
            if (lastSpecial != wstring::npos && ppathStr[lastSpecial] == L'.')
                restrictmlftokeys.insert(ppathStr.substr(0, lastSpecial));
            else
                restrictmlftokeys.insert(ppathStr);
        }
    }

    // get labels
    double htktimetoframe = 100000.0; // default is 10ms
    std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> labelsmulti;
    const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : nullptr;
    foreach_index (i, mlfpathsmulti)
    {
        msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
        labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map<string, size_t>*) NULL, htktimetoframe); // label MLF

        // Make sure 'msra::asr::htkmlfreader' type has a move constructor
        static_assert(std::is_move_constructible<msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>>::value,
                      "Type 'msra::asr::htkmlfreader' should be move constructible!");

        labelsmulti.push_back(std::move(labels));
    }

    if (useBlockRandomize)
    {
        // construct all the parameters we don't need, but need to be passed to the constructor...

        m_lattices.reset(new msra::dbn::latticesource(latticetocs, m_hset.getsymmap(), RootPathInLatticeTocs));
        m_lattices->setverbosity(m_verbosity);

        // now get the frame source. This has better randomization and doesn't create temp files
        bool useMersenneTwisterRand = readerConfig(L"useMersenneTwisterRand", false);
        m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, m_featDims, m_labelDims,
                                                                         numContextLeft, numContextRight, randomize, 
                                                                         *m_lattices, m_latticeMap, m_frameMode, 
                                                                         m_expandToUtt, m_maxUtteranceLength, m_truncated));
        m_frameSource->setverbosity(m_verbosity);
    }
    else // useRollingWindow (readMethod was validated above)
    {
        // On Windows pageFilePath is a directory handed to GetTempFileName().
        // On Unix it becomes a mkstemp() template ending in "XXXXXX".
        std::wstring pageFilePath;
        std::vector<std::wstring> pagePaths;
        if (readerConfig.Exists(L"pageFilePath"))
        {
            pageFilePath = (const wstring&) readerConfig(L"pageFilePath");

#ifdef _WIN32
            // replace any '/' with '\' for compat with default path
            std::replace(pageFilePath.begin(), pageFilePath.end(), '/', '\\');

            // verify path exists
            DWORD attrib = GetFileAttributes(pageFilePath.c_str());
            if (attrib == INVALID_FILE_ATTRIBUTES || !(attrib & FILE_ATTRIBUTE_DIRECTORY))
                RuntimeError("pageFilePath does not exist");
#endif
#ifdef __unix__
            // verify the directory exists (same requirement as on Windows)
            struct stat statbuf;
            if (stat(wtocharpath(pageFilePath).c_str(), &statbuf) == -1 || !S_ISDIR(statbuf.st_mode))
            {
                RuntimeError("pageFilePath does not exist");
            }

            // mkstemp() requires a file-name template ending in "XXXXXX", not a bare directory
            pageFilePath.erase(pageFilePath.find_last_not_of(L'/') + 1);
            pageFilePath += L"/temp.CNTK.XXXXXX";
#endif
        }
        else // using default temporary path
        {
#ifdef _WIN32
            // GetTempPath writes into the buffer, so the string must have real size (not just capacity)
            pageFilePath.resize(MAX_PATH);
            const DWORD tempPathLength = GetTempPath(MAX_PATH, &pageFilePath[0]);
            if (tempPathLength == 0 || tempPathLength > MAX_PATH)
                RuntimeError("cannot determine the temporary directory for page files");
            pageFilePath.resize(tempPathLength);
#endif
#ifdef __unix__
            pageFilePath = L"/tmp/temp.CNTK.XXXXXX";
#endif
        }

#ifdef _WIN32
        if (pageFilePath.size() > MAX_PATH - 14) // max length of input to GetTempFileName is MAX_PATH-14
            RuntimeError("pageFilePath must be less than %d characters", MAX_PATH - 14);
#else
        if (pageFilePath.size() > PATH_MAX - 14) // keep the mkstemp() template well within the OS path limit
            RuntimeError("pageFilePath must be less than %d characters", PATH_MAX - 14);
#endif
#ifdef __unix__
        // narrow (UTF-8) form of the mkstemp() template; copied for each page file since mkstemp rewrites it
        const std::string pageFileTemplate(Microsoft::MSR::CNTK::ToLegacyString(Microsoft::MSR::CNTK::ToUTF8(pageFilePath)).c_str());
#endif
        // one page file per input stream
        foreach_index (i, infilesmulti)
        {
#ifdef _WIN32
            wchar_t tempFile[MAX_PATH];
            if (GetTempFileName(pageFilePath.c_str(), L"CNTK", 0, tempFile) == 0)
                RuntimeError("cannot create temporary page file in '%ls'", pageFilePath.c_str());
            pagePaths.push_back(tempFile);
#endif
#ifdef __unix__
            std::string tempFile = pageFileTemplate; // mkstemp() replaces the trailing XXXXXX in place
            int fid = mkstemp(&tempFile[0]);
            if (fid == -1)
                RuntimeError("cannot create temporary page file from template '%ls'", pageFilePath.c_str());
            // only the unique name is needed here; the file itself is removed again
            unlink(tempFile.c_str());
            close(fid);
            pagePaths.push_back(GetWC(tempFile.c_str()));
#endif
        }

        const bool mayhavenoframe = false;
        int addEnergy = 0;

        m_frameSource.reset(new msra::dbn::minibatchframesourcemulti(infilesmulti, labelsmulti, m_featDims, m_labelDims, 
                                                                     numContextLeft, numContextRight, randomize, 
                                                                     pagePaths, mayhavenoframe, addEnergy));
        m_frameSource->setverbosity(m_verbosity);
    }
}