void HTKMLFReader::PrepareForTrainingOrTesting()

in Source/Readers/Kaldi2Reader/HTKMLFReader.cpp [245:564]


// Prepares the reader for training or testing from <readerConfig>:
//  - loads the sequence-training files when m_doSeqTrain is set,
//  - resets the per-utterance feature/label buffer bookkeeping,
//  - parses every feature section ("dim", "contextWindow", "type", "scpFile",
//    "rx", "featureTransform") and every label section ("labelDim"/"dim",
//    "labelType"/"type", "labelMappingFile", "mlfFile",
//    "labelToTargetMappingFile", "targetDim"),
//  - reads the feature script files and the label MLF files, and
//  - creates m_frameSource according to "readMethod": "blockRandomize"
//    (default) or "rollingWindow".
// Configuration problems are reported through RuntimeError / InvalidArgument /
// LogicError, or by throwing std::runtime_error.
void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType& readerConfig)
{
    // Loads files for sequence training.
    if (m_doSeqTrain)
    {
        PrepareForSequenceTraining(readerConfig);
    }

    // Variables related to multi-utterance.
    // m_featuresBufferMultiUtt:
    //     Holds pointers to the data trunk for each utterance.
    // m_featuresBufferAllocatedMultiUtt:
    //     Actual data stores here.
    m_featuresBufferMultiUtt.assign(m_numberOfuttsPerMinibatch, NULL);
    m_featuresBufferAllocatedMultiUtt.assign(m_numberOfuttsPerMinibatch, 0);
    m_labelsBufferMultiUtt.assign(m_numberOfuttsPerMinibatch, NULL);
    m_labelsBufferAllocatedMultiUtt.assign(m_numberOfuttsPerMinibatch, 0);

    // Gets a list of features and labels. Note that we assume feature
    // section has sub-field "scpFile" and label section has sub-field
    // "mlfFile".
    std::vector<std::wstring> featureNames;
    std::vector<std::wstring> labelNames;
    GetDataNamesFromConfig(readerConfig, featureNames, labelNames);
    // NOTE(review): this only rejects configurations with at most one section
    // in total; it does not by itself guarantee one input AND one output as
    // the message suggests. Left unchanged -- confirm the intended rule.
    if (featureNames.size() + labelNames.size() <= 1)
    {
        RuntimeError("network needs at least 1 input and 1 output specified!");
    }

    // Loads feature files.
    size_t iFeat = 0;
    vector<size_t> numContextLeft;
    vector<size_t> numContextRight;
    vector<msra::asr::FeatureSection*>& scriptpaths = m_trainingOrTestingFeatureSections;
    foreach_index (i, featureNames)
    {
        const ConfigRecordType& thisFeature = readerConfig(featureNames[i]);
        m_featDims.push_back(thisFeature(L"dim"));
        intargvector contextWindow = thisFeature(L"contextWindow", ConfigRecordType::Array(intargvector(vector<int>{1})));
        if (contextWindow.size() == 1) // symmetric
        {
            size_t windowFrames = contextWindow[0];
            if (windowFrames % 2 == 0)
                RuntimeError("augmentationextent: neighbor expansion of input features to %d not symmetrical", (int) windowFrames);
            size_t context = windowFrames / 2; // extend each side by this
            numContextLeft.push_back(context);
            numContextRight.push_back(context);
        }
        else if (contextWindow.size() == 2) // left context, right context
        {
            numContextLeft.push_back(contextWindow[0]);
            numContextRight.push_back(contextWindow[1]);
        }
        else
        {
            RuntimeError("contextFrames must have 1 or 2 values specified, found %d", (int) contextWindow.size());
        }

        // Figures the actual feature dimension, with context.
        m_featDims[i] = m_featDims[i] * (1 + numContextLeft[i] + numContextRight[i]);

        // Figures out the category.
        wstring type = thisFeature(L"type", L"real");
        if (EqualCI(type, L"real"))
        {
            m_nameToTypeMap[featureNames[i]] = InputOutputTypes::real;
        }
        else
        {
            InvalidArgument("feature type must be 'real'");
        }
        m_featureNameToIdMap[featureNames[i]] = iFeat;
        assert(iFeat == m_featureIdToNameMap.size());
        m_featureIdToNameMap.push_back(featureNames[i]);
        scriptpaths.push_back(new msra::asr::FeatureSection(thisFeature(L"scpFile"), thisFeature(L"rx"), thisFeature(L"featureTransform", L"")));
        m_featureNameToDimMap[featureNames[i]] = m_featDims[i];

        m_featuresBufferMultiIO.push_back(NULL);
        m_featuresBufferAllocatedMultiIO.push_back(0);

        iFeat++;
    }

    // Loads label files.
    size_t iLabel = 0;
    vector<wstring> statelistpaths;
    vector<wstring> mlfpaths;
    vector<vector<wstring>> mlfpathsmulti;
    foreach_index (i, labelNames)
    {
        const ConfigRecordType& thisLabel = readerConfig(labelNames[i]);

        // Figures out label dimension.
        if (thisLabel.Exists(L"labelDim"))
            m_labelDims.push_back(thisLabel(L"labelDim"));
        else if (thisLabel.Exists(L"dim"))
            m_labelDims.push_back(thisLabel(L"dim"));
        else
            InvalidArgument("labels must specify dim or labelDim");

        // Figures out the category.
        wstring type;
        if (thisLabel.Exists(L"labelType"))
            type = (const wstring&) thisLabel(L"labelType"); // let's deprecate this eventually and just use "type"...
        else
            type = (const wstring&) thisLabel(L"type", L"category"); // outputs should default to category
        if (EqualCI(type, L"category"))
            m_nameToTypeMap[labelNames[i]] = InputOutputTypes::category;
        else
            InvalidArgument("label type must be Category");

        // Loads label mapping.
        statelistpaths.push_back(thisLabel(L"labelMappingFile", L""));

        m_labelNameToIdMap[labelNames[i]] = iLabel;
        assert(iLabel == m_labelIdToNameMap.size());
        m_labelIdToNameMap.push_back(labelNames[i]);
        m_labelNameToDimMap[labelNames[i]] = m_labelDims[i];
        mlfpaths.clear();
        mlfpaths.push_back(thisLabel(L"mlfFile"));
        mlfpathsmulti.push_back(mlfpaths);

        m_labelsBufferMultiIO.push_back(NULL);
        m_labelsBufferAllocatedMultiIO.push_back(0);

        iLabel++;

        // Figures out label to target mapping.
        wstring labelToTargetMappingFile(thisLabel(L"labelToTargetMappingFile", L""));
        if (labelToTargetMappingFile != L"")
        {
            std::vector<std::vector<ElemType>> labelToTargetMap;
            m_convertLabelsToTargetsMultiIO.push_back(true);
            if (thisLabel.Exists(L"targetDim"))
            {
                // targetDim overrides the label dimension recorded above.
                m_labelNameToDimMap[labelNames[i]] = m_labelDims[i] = thisLabel(L"targetDim");
            }
            else
                RuntimeError("output must specify targetDim if labelToTargetMappingFile specified!");
            size_t targetDim = ReadLabelToTargetMappingFile(labelToTargetMappingFile, statelistpaths[i], labelToTargetMap);
            if (targetDim != m_labelDims[i])
                RuntimeError("mismatch between targetDim and dim found in labelToTargetMappingFile");
            m_labelToTargetMapMultiIO.push_back(labelToTargetMap);
        }
        else
        {
            m_convertLabelsToTargetsMultiIO.push_back(false);
            m_labelToTargetMapMultiIO.push_back(std::vector<std::vector<ElemType>>());
        }
    }

    // Sanity check.
    if (iFeat != scriptpaths.size() || iLabel != mlfpathsmulti.size())
        throw std::runtime_error(msra::strfun::strprintf("# of inputs files vs. # of inputs or # of output files vs # of outputs inconsistent\n"));

    // Loads randomization method: "none", "auto", or a numeric value.
    size_t randomize = randomizeAuto;
    if (readerConfig.Exists(L"randomize"))
    {
        const std::string& randomizeString = readerConfig(L"randomize");
        if (EqualCI(randomizeString, "none"))
        {
            randomize = randomizeNone;
        }
        else if (EqualCI(randomizeString, "auto"))
        {
            randomize = randomizeAuto;
        }
        else
        {
            randomize = readerConfig(L"randomize");
        }
    }

    // Open script files for features.
    size_t numFiles = 0;
    size_t firstfilesonly = SIZE_MAX; // set to a lower value for testing
    vector<wstring> filelist;
    vector<vector<wstring>> infilesmulti;
    foreach_index (i, scriptpaths)
    {
        filelist.clear();
        std::wstring scriptpath = scriptpaths[i]->scpFile;
        fprintf(stderr, "reading script file %S ...", scriptpath.c_str());
        size_t n = 0;
        // Stop once <firstfilesonly> entries were read (strict '<' so that a
        // lowered limit is not exceeded by one entry).
        for (msra::files::textreader reader(scriptpath); reader && filelist.size() < firstfilesonly /*optimization*/;)
        {
            filelist.push_back(reader.wgetline());
            n++;
        }

        // size_t is not 'unsigned long' on every platform, hence the cast.
        fprintf(stderr, " %lu entries\n", (unsigned long) n);

        // All script files must list the same number of entries.
        if (i == 0)
            numFiles = n;
        else if (n != numFiles)
            throw std::runtime_error(msra::strfun::strprintf("number of files in each scriptfile inconsistent (%d vs. %d)", (int) numFiles, (int) n));

        infilesmulti.push_back(filelist);
    }

    // Opens MLF files for labels.
    set<wstring> restrictmlftokeys;
    double htktimetoframe = 100000.0; // default is 10ms
    std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> labelsmulti;
    int targets_delay = 0;
    if (readerConfig.Exists(L"targets_delay"))
    {
        targets_delay = readerConfig(L"targets_delay");
    }
    foreach_index (i, mlfpathsmulti)
    {
        msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
            labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], htktimetoframe, targets_delay); // label MLF
        labelsmulti.push_back(labels);
    }

    // Get the readMethod, default value is "blockRandomize", the other
    // option is "rollingWindow". We only support "blockRandomize" in
    // sequence training.
    std::string readMethod(readerConfig(L"readMethod", "blockRandomize"));
    if (EqualCI(readMethod, "blockRandomize"))
    {
        // construct all the parameters we don't need, but need to be passed to the constructor...
        std::pair<std::vector<wstring>, std::vector<wstring>> latticetocs;
        std::unordered_map<std::string, size_t> modelsymmap;

        // Note, we are actually not using <m_lattices>, the only reason we
        // kept it was because it was required by
        // <minibatchutterancesourcemulti>.
        m_lattices = new msra::dbn::latticesource(latticetocs, modelsymmap, L"");

        // now get the frame source. This has better randomization and doesn't create temp files
        m_frameSource = new msra::dbn::minibatchutterancesourcemulti(
            scriptpaths, infilesmulti, labelsmulti, m_featDims, m_labelDims,
            numContextLeft, numContextRight, randomize, *m_lattices, m_latticeMap, m_framemode);
    }
    else if (EqualCI(readMethod, "rollingWindow"))
    {
        // "rollingWindow" is not supported in sequence training.
        if (m_doSeqTrain)
        {
            LogicError("rollingWindow is not supported in sequence training.\n");
        }
        std::wstring pageFilePath;
        std::vector<std::wstring> pagePaths;
#ifdef __unix__
        // Narrow-character template for mkstemp(); it must end in "XXXXXX".
        std::string pageFileTemplate;
#endif
        if (readerConfig.Exists(L"pageFilePath"))
        {
            pageFilePath = (const wstring&) readerConfig(L"pageFilePath");

#ifdef _WIN32
            // replace any '/' with '\' for compat with default path
            // (Windows only: on Unix this would corrupt the path)
            std::replace(pageFilePath.begin(), pageFilePath.end(), L'/', L'\\');
            // verify path exists
            DWORD attrib = GetFileAttributes(pageFilePath.c_str());
            if (attrib == INVALID_FILE_ATTRIBUTES || !(attrib & FILE_ATTRIBUTE_DIRECTORY))
                throw std::runtime_error("pageFilePath does not exist");
#endif
#ifdef __unix__
            // verify path exists and is a directory (mirrors the Windows check)
            const std::string pageFileDir(wtocharpath(pageFilePath).c_str());
            struct stat statbuf;
            if (stat(pageFileDir.c_str(), &statbuf) == -1 || !S_ISDIR(statbuf.st_mode))
            {
                RuntimeError("pageFilePath does not exist");
            }
            pageFileTemplate = pageFileDir + "/CNTK.XXXXXX";
#endif
        }
        else // using default temporary path
        {
#ifdef _WIN32
            // GetTempPath() needs a real buffer; writing into storage that was
            // only reserve()d by a std::wstring is undefined behavior and
            // leaves the string's size at 0.
            wchar_t tempDir[MAX_PATH + 1] = {0};
            const DWORD tempDirLen = GetTempPath(MAX_PATH + 1, tempDir);
            if (tempDirLen == 0 || tempDirLen > MAX_PATH)
                throw std::runtime_error("failed to determine the temporary directory for page files");
            pageFilePath = tempDir;
#endif
#ifdef __unix__
            pageFilePath = L"/tmp/temp.CNTK.XXXXXX";
            pageFileTemplate = "/tmp/temp.CNTK.XXXXXX";
#endif
        }

#ifdef _WIN32
        if (pageFilePath.size() > MAX_PATH - 14) // max length of input to GetTempFileName is MAX_PATH-14
            throw std::runtime_error(msra::strfun::strprintf("pageFilePath must be less than %d characters", MAX_PATH - 14));
#endif
#ifdef __unix__
        if (pageFilePath.size() > PATH_MAX - 14) // same limit as on Windows; also leaves room for the template suffix
            throw std::runtime_error(msra::strfun::strprintf("pageFilePath must be less than %d characters", PATH_MAX - 14));
#endif
        // One page file name per feature stream.
        foreach_index (i, infilesmulti)
        {
#ifdef _WIN32
            wchar_t tempFile[MAX_PATH];
            if (GetTempFileName(pageFilePath.c_str(), L"CNTK", 0, tempFile) == 0)
                throw std::runtime_error("failed to create a temporary page file in pageFilePath");
            pagePaths.push_back(tempFile);
#endif
#ifdef __unix__
            // mkstemp() rewrites its argument in place, so every iteration
            // works on a fresh, writable, NUL-terminated copy of the template.
            std::vector<char> tempFile(pageFileTemplate.begin(), pageFileTemplate.end());
            tempFile.push_back('\0');
            int fid = mkstemp(tempFile.data());
            if (fid == -1)
            {
                RuntimeError("failed to create a temporary page file from template %s", pageFileTemplate.c_str());
            }
            // Only the generated name is kept; the file itself is removed again.
            unlink(tempFile.data());
            close(fid);
            pagePaths.push_back(GetWC(tempFile.data()));
#endif
        }

        const bool mayhavenoframe = false;
        int addEnergy = 0;

        int verbosity = readerConfig(L"verbosity", 2);
        m_frameSource = new msra::dbn::minibatchframesourcemulti(scriptpaths, infilesmulti, labelsmulti, m_featDims, m_labelDims, numContextLeft, numContextRight, randomize, pagePaths, mayhavenoframe, addEnergy);
        m_frameSource->setverbosity(verbosity);
    }
    else
    {
        RuntimeError("readMethod must be rollingWindow or blockRandomize");
    }
}