in Source/Readers/HTKMLFReader/HTKMLFReader.cpp [108:603]
void HTKMLFReader<ElemType>::PrepareForTrainingOrTesting(const ConfigRecordType& readerConfig)
{
    // Configures the reader for training/testing from 'readerConfig':
    //  - parses every feature, label, lattice and HMM section named in the config
    //    (names obtained via GetDataNamesFromConfig),
    //  - reads the feature script (SCP) files and the label MLF files,
    //  - optionally loads a unigram LM ("unigram" option),
    //  - constructs m_frameSource: a minibatchutterancesourcemulti for readMethod "blockRandomize"
    //    (the default) or a minibatchframesourcemulti backed by page files for "rollingWindow".
    // Configuration errors are reported through InvalidArgument()/RuntimeError().

    vector<wstring> scriptpaths;
    vector<wstring> RootPathInScripts;
    wstring RootPathInLatticeTocs;
    vector<wstring> mlfpaths;
    vector<vector<wstring>> mlfpathsmulti;
    size_t firstfilesonly = SIZE_MAX; // set to a lower value for testing (max. number of entries read from each script)
    vector<vector<wstring>> infilesmulti;
    size_t numFiles = 0;
    wstring unigrampath(L"");
    size_t randomize = randomizeAuto;
    size_t iFeat = 0;
    size_t iLabel = 0;
    vector<wstring> statelistpaths;
    vector<size_t> numContextLeft;
    vector<size_t> numContextRight;
    size_t numExpandToUtt = 0;
    std::vector<std::wstring> featureNames;
    std::vector<std::wstring> labelNames;
    // for hmm and lattice
    std::vector<std::wstring> hmmNames;
    std::vector<std::wstring> latticeNames;
    GetDataNamesFromConfig(readerConfig, featureNames, labelNames, hmmNames, latticeNames);
    if (featureNames.size() + labelNames.size() <= 1)
    {
        InvalidArgument("network needs at least 1 input and 1 output specified!");
    }

    // load data for all real-valued inputs (features)
    foreach_index (i, featureNames)
    {
        const ConfigRecordType& thisFeature = readerConfig(featureNames[i]);
        m_featDims.push_back(thisFeature(L"dim"));
        bool expandToUtt = thisFeature(L"expandToUtterance", false); // should feature be processed as an ivector?
        m_expandToUtt.push_back(expandToUtt);
        if (expandToUtt)
            numExpandToUtt++;

        // contextWindow is either a single odd window size (symmetric) or a (left, right) pair
        intargvector contextWindow = thisFeature(L"contextWindow", ConfigRecordType::Array(intargvector(vector<int>{1})));
        if (contextWindow.size() == 1) // symmetric
        {
            size_t windowFrames = contextWindow[0];
            if (windowFrames % 2 == 0)
                InvalidArgument("augmentationextent: neighbor expansion of input features to %d not symmetrical", (int) windowFrames);
            size_t context = windowFrames / 2; // extend each side by this
            numContextLeft.push_back(context);
            numContextRight.push_back(context);
        }
        else if (contextWindow.size() == 2) // left context, right context
        {
            numContextLeft.push_back(contextWindow[0]);
            numContextRight.push_back(contextWindow[1]);
        }
        else
        {
            InvalidArgument("contextWindow must have 1 or 2 values specified, found %d", (int) contextWindow.size());
        }
        if (expandToUtt && (numContextLeft[i] != 0 || numContextRight[i] != 0))
            RuntimeError("contextWindow expansion not permitted when expandToUtterance=true");

        // update m_featDims to reflect the total input dimension (featDim x contextWindow), not the native feature dimension
        // that is what the lower level feature readers expect
        m_featDims[i] = m_featDims[i] * (1 + numContextLeft[i] + numContextRight[i]);

        wstring type = thisFeature(L"type", L"real");
        if (EqualCI(type, L"real"))
        {
            m_nameToTypeMap[featureNames[i]] = InputOutputTypes::real;
        }
        else
        {
            InvalidArgument("feature type must be 'real'");
        }

        m_featureNameToIdMap[featureNames[i]] = iFeat;
        wstring type2 = thisFeature(L"scpFile");
        scriptpaths.push_back(type2);
        RootPathInScripts.push_back(thisFeature(L"prefixPathInSCP", L""));
        m_featureNameToDimMap[featureNames[i]] = m_featDims[i];

        m_featuresBufferMultiIO.push_back(nullptr);
        m_featuresBufferAllocatedMultiIO.push_back(0);

        iFeat++;
    }

    // load all outputs (labels)
    foreach_index (i, labelNames)
    {
        const ConfigRecordType& thisLabel = readerConfig(labelNames[i]);
        if (thisLabel.Exists(L"labelDim"))
            m_labelDims.push_back(thisLabel(L"labelDim"));
        else if (thisLabel.Exists(L"dim"))
            m_labelDims.push_back(thisLabel(L"dim"));
        else
            InvalidArgument("labels must specify dim or labelDim");

        wstring type;
        if (thisLabel.Exists(L"labelType"))
            type = (const wstring&) thisLabel(L"labelType"); // let's deprecate this eventually and just use "type"...
        else
            type = (const wstring&) thisLabel(L"type", L"category"); // outputs should default to category

        if (EqualCI(type, L"category"))
            m_nameToTypeMap[labelNames[i]] = InputOutputTypes::category;
        else
            InvalidArgument("label type must be 'category'");

        statelistpaths.push_back(thisLabel(L"labelMappingFile", L""));

        m_labelNameToIdMap[labelNames[i]] = iLabel;
        m_labelNameToDimMap[labelNames[i]] = m_labelDims[i];

        // the MLF(s) come either from a single "mlfFile" or from a text file listing one MLF path per line
        mlfpaths.clear();
        if (thisLabel.ExistsCurrent(L"mlfFile"))
        {
            wstring type2 = thisLabel(L"mlfFile");
            mlfpaths.push_back(type2);
        }
        else
        {
            if (!thisLabel.ExistsCurrent(L"mlfFileList"))
            {
                InvalidArgument("Either mlfFile or mlfFileList must exist in HTKMLFReader");
            }
            wstring list = thisLabel(L"mlfFileList");
            for (msra::files::textreader r(list); r;)
            {
                mlfpaths.push_back(r.wgetline());
            }
        }
        mlfpathsmulti.push_back(mlfpaths);

        m_labelsBufferMultiIO.push_back(nullptr);
        m_labelsBufferAllocatedMultiIO.push_back(0);

        iLabel++;

        wstring labelToTargetMappingFile(thisLabel(L"labelToTargetMappingFile", L""));
        if (labelToTargetMappingFile != L"")
        {
            std::vector<std::vector<ElemType>> labelToTargetMap;
            m_convertLabelsToTargetsMultiIO.push_back(true);
            if (thisLabel.Exists(L"targetDim"))
            {
                m_labelNameToDimMap[labelNames[i]] = m_labelDims[i] = thisLabel(L"targetDim");
            }
            else
            {
                RuntimeError("output must specify targetDim if labelToTargetMappingFile specified!");
            }
            size_t targetDim = ReadLabelToTargetMappingFile(labelToTargetMappingFile, statelistpaths[i], labelToTargetMap);
            if (targetDim != m_labelDims[i])
                RuntimeError("mismatch between targetDim and dim found in labelToTargetMappingFile");
            m_labelToTargetMapMultiIO.push_back(std::move(labelToTargetMap));
        }
        else
        {
            m_convertLabelsToTargetsMultiIO.push_back(false);
            m_labelToTargetMapMultiIO.push_back(std::vector<std::vector<ElemType>>());
        }
    }

    // get lattice toc file names (first: numerator tocs, second: denominator tocs)
    std::pair<std::vector<wstring>, std::vector<wstring>> latticetocs;
    foreach_index (i, latticeNames) // only support one set of lattice now
    {
        const ConfigRecordType& thisLattice = readerConfig(latticeNames[i]);

        vector<wstring> paths;
        expand_wildcards(thisLattice(L"denLatTocFile"), paths);
        latticetocs.second.insert(latticetocs.second.end(), paths.begin(), paths.end());

        if (thisLattice.Exists(L"numLatTocFile"))
        {
            paths.clear();
            expand_wildcards(thisLattice(L"numLatTocFile"), paths);
            latticetocs.first.insert(latticetocs.first.end(), paths.begin(), paths.end());
        }
        RootPathInLatticeTocs = (wstring) thisLattice(L"prefixPathInToc", L"");
    }

    // get HMM related file names
    vector<wstring> cdphonetyingpaths, transPspaths;
    foreach_index (i, hmmNames)
    {
        const ConfigRecordType& thisHMM = readerConfig(hmmNames[i]);

        wstring type2 = thisHMM(L"phoneFile");
        cdphonetyingpaths.push_back(type2);
        transPspaths.push_back(thisHMM(L"transPFile", L""));
    }

    // mmf files
    // only support one set now
    if (cdphonetyingpaths.size() > 0 && statelistpaths.size() > 0 && transPspaths.size() > 0)
        m_hset.loadfromfile(cdphonetyingpaths[0], statelistpaths[0], transPspaths[0]);

    if (iFeat != scriptpaths.size() || iLabel != mlfpathsmulti.size())
        RuntimeError("# of inputs files vs. # of inputs or # of output files vs # of outputs inconsistent");

    // this check also rejects configs without any feature stream, which keeps m_expandToUtt[0] below in range
    if (iFeat == numExpandToUtt)
        RuntimeError("At least one feature stream must be frame-based, not utterance-based");

    if (m_expandToUtt[0]) // first feature stream is ivector type - that will mess up lower level feature reader
        RuntimeError("The first feature stream in the file must be frame-based not utterance based. Please reorder the feature blocks of your config appropriately");

    if (readerConfig.Exists(L"randomize"))
    {
        wstring randomizeString = readerConfig.CanBeString(L"randomize") ? readerConfig(L"randomize") : wstring();
        if (EqualCI(randomizeString, L"none")) randomize = randomizeNone;
        else if (EqualCI(randomizeString, L"auto")) randomize = randomizeAuto;
        else randomize = readerConfig(L"randomize"); // TODO: could this not just be randomizeString?
    }

    m_frameMode = readerConfig(L"frameMode", true);
    m_verbosity = readerConfig(L"verbosity", 0);

    if (m_frameMode && m_truncated)
    {
        InvalidArgument("'Truncated' cannot be 'true' in frameMode (i.e. when 'frameMode' is 'true')");
    }

    // determine if we partial minibatches are desired
    wstring minibatchMode(readerConfig(L"minibatchMode", L"partial"));
    m_partialMinibatch = EqualCI(minibatchMode, L"partial");

    // get the read method, defaults to "blockRandomize" other option is "rollingWindow".
    // Compared case-insensitively everywhere, so that the validation below and the dispatch at the end agree.
    wstring readMethod(readerConfig(L"readMethod", L"blockRandomize"));

    if (EqualCI(readMethod, L"blockRandomize") && randomize == randomizeNone)
        InvalidArgument("'randomize' cannot be 'none' when 'readMethod' is 'blockRandomize'.");

    if (EqualCI(readMethod, L"rollingWindow") && numExpandToUtt > 0)
        RuntimeError("rollingWindow reader does not support expandToUtt. Change to blockRandomize.");

    // read all input files (from multiple inputs)
    // every script file must list the same number of entries
    foreach_index (i, scriptpaths)
    {
        vector<wstring> filelist;
        std::wstring scriptpath = scriptpaths[i];
        fprintf(stderr, "reading script file %ls ...", scriptpath.c_str());
        for (msra::files::textreader reader(scriptpath); reader && filelist.size() < firstfilesonly /*optimization*/;)
        {
            filelist.push_back(reader.wgetline());
        }
        const size_t n = filelist.size();
        fprintf(stderr, " %lu entries\n", (unsigned long) n);

        if (i == 0)
            numFiles = n;
        else if (n != numFiles)
            RuntimeError("number of files in each scriptfile inconsistent (%d vs. %d)", (int) numFiles, (int) n);

        // post processing file list :
        //  - if users specified PrefixPath, add the prefix to each of path in filelist
        //  - else do the dotdotdot expansion if necessary
        wstring rootpath = RootPathInScripts[i];
        if (!rootpath.empty()) // user has specified a path prefix for this feature
        {
            // first make slash consistent (sorry for linux users:this is not necessary for you)
            std::replace(rootpath.begin(), rootpath.end(), L'\\', L'/');

            // second, remove trailing slashes if there are any (a prefix made only of slashes becomes empty)
            const size_t lastNonSlash = rootpath.find_last_not_of(L'/');
            rootpath.erase(lastNonSlash == wstring::npos ? 0 : lastNonSlash + 1);

            // third, join the rootpath with each entry in filelist
            if (!rootpath.empty())
            {
                for (wstring& path : filelist)
                {
                    // for "key=path" entries only the part after the first '=' gets the prefix
                    const size_t eq = path.find(L'=');
                    if (eq != wstring::npos)
                    {
                        wstring physicalPath = path.substr(eq + 1);
#ifdef _WIN32
                        std::replace(physicalPath.begin(), physicalPath.end(), L'\\', L'/');
#endif
                        path = path.substr(0, eq) + L"=" + rootpath + L"/" + physicalPath;
                    }
                    else
                    {
#ifdef _WIN32
                        std::replace(path.begin(), path.end(), L'\\', L'/');
#endif
                        path = rootpath + L"/" + path;
                    }
                }
            }
        }
        else
        {
            /*
            do "..." expansion if SCP uses relative path names
            "..." in the SCP means full path is the same as the SCP file
            for example, if scp file is "//aaa/bbb/ccc/ddd.scp"
            and contains entry like
            .../file1.feat
            .../file2.feat
            etc.
            the features will be read from
            // aaa/bbb/ccc/file1.feat
            // aaa/bbb/ccc/file2.feat
            etc.
            This works well if you store the scp file with the features but
            do not want different scp files everytime you move or create new features
            */
            wstring scpdircached;
            for (auto& entry : filelist)
                ExpandDotDotDot(entry, scriptpath, scpdircached);
        }

        infilesmulti.push_back(std::move(filelist));
    }

    if (readerConfig.Exists(L"unigram"))
        unigrampath = (const wstring&) readerConfig(L"unigram");

    // load a unigram if needed (this is used for MMI training)
    msra::lm::CSymbolSet unigramsymbols;
    std::unique_ptr<msra::lm::CMGramLM> unigram;
    size_t silencewordid = SIZE_MAX;
    size_t startwordid = SIZE_MAX;
    size_t endwordid = SIZE_MAX;
    if (unigrampath != L"")
    {
        unigram.reset(new msra::lm::CMGramLM());
        unigram->read(unigrampath, unigramsymbols, false /*filterVocabulary--false will build the symbol map*/, 1 /*maxM--unigram only*/);
        silencewordid = unigramsymbols["!silence"]; // give this an id (even if not in the LM vocabulary)
        startwordid = unigramsymbols["<s>"];
        endwordid = unigramsymbols["</s>"];
    }

    if (!unigram && latticetocs.second.size() > 0)
        fprintf(stderr, "trainlayer: OOV-exclusion code enabled, but no unigram specified to derive the word set from, so you won't get OOV exclusion\n");

    // currently assumes all mlfs will have same root name (key)
    set<wstring> restrictmlftokeys; // restrict MLF reader to these files--will make stuff much faster without having to use shortened input files
    if (infilesmulti[0].size() <= 100)
    {
        for (const wstring& infile : infilesmulti[0])
        {
            msra::asr::htkfeatreader::parsedpath ppath(infile);
            const wstring ppathStr = (wstring) ppath;

            // delete extension (or not if none): an extension is a '.' that appears after the last
            // '\\', '/' or ':'. Paths without any of these characters (including empty ones) are kept whole.
            const size_t pos = ppathStr.find_last_of(L".\\/:");
            if (pos != wstring::npos && ppathStr[pos] == L'.')
                restrictmlftokeys.insert(ppathStr.substr(0, pos));
            else
                restrictmlftokeys.insert(ppathStr);
        }
    }

    // get labels
    double htktimetoframe = 100000.0; // default is 10ms
    std::vector<std::map<std::wstring, std::vector<msra::asr::htkmlfentry>>> labelsmulti;
    const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : nullptr; // same for every label stream
    foreach_index (i, mlfpathsmulti)
    {
        msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>
            labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map<string, size_t>*) NULL, htktimetoframe); // label MLF

        // Make sure 'msra::asr::htkmlfreader' type has a move constructor
        static_assert(std::is_move_constructible<msra::asr::htkmlfreader<msra::asr::htkmlfentry, msra::lattices::lattice::htkmlfwordsequence>>::value,
                      "Type 'msra::asr::htkmlfreader' should be move constructible!");

        labelsmulti.push_back(std::move(labels));
    }

    if (EqualCI(readMethod, L"blockRandomize"))
    {
        // construct all the parameters we don't need, but need to be passed to the constructor...
        m_lattices.reset(new msra::dbn::latticesource(latticetocs, m_hset.getsymmap(), RootPathInLatticeTocs));
        m_lattices->setverbosity(m_verbosity);

        // now get the frame source. This has better randomization and doesn't create temp files
        bool useMersenneTwisterRand = readerConfig(L"useMersenneTwisterRand", false);
        m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, m_featDims, m_labelDims,
                                                                         numContextLeft, numContextRight, randomize,
                                                                         *m_lattices, m_latticeMap, m_frameMode,
                                                                         m_expandToUtt, m_maxUtteranceLength, m_truncated));
        m_frameSource->setverbosity(m_verbosity);
    }
    else if (EqualCI(readMethod, L"rollingWindow"))
    {
        std::wstring pageFilePath;
        std::vector<std::wstring> pagePaths;
        if (readerConfig.Exists(L"pageFilePath"))
        {
            pageFilePath = (const wstring&) readerConfig(L"pageFilePath");

#ifdef _WIN32
            // replace any '/' with '\' for compat with default path
            std::replace(pageFilePath.begin(), pageFilePath.end(), '/', '\\');

            // verify path exists
            DWORD attrib = GetFileAttributes(pageFilePath.c_str());
            if (attrib == INVALID_FILE_ATTRIBUTES || !(attrib & FILE_ATTRIBUTE_DIRECTORY))
                RuntimeError("pageFilePath does not exist");
#endif
#ifdef __unix__
            // pageFilePath must be an existing directory (same requirement as on Windows). Previously a regular
            // file was accepted; mkstemp() then failed on it and the unchecked unlink() deleted that file.
            struct stat statbuf;
            if (stat(wtocharpath(pageFilePath).c_str(), &statbuf) == -1 || !S_ISDIR(statbuf.st_mode))
            {
                RuntimeError("pageFilePath does not exist");
            }
            // mkstemp() needs a file name template ending in "XXXXXX" inside that directory
            if (!pageFilePath.empty() && pageFilePath.back() != L'/')
                pageFilePath += L'/';
            pageFilePath += L"temp.CNTK.XXXXXX";
#endif
        }
        else // using default temporary path
        {
#ifdef _WIN32
            // GetTempPath writes up to MAX_PATH+1 characters; the string must be *sized* (not merely reserved)
            // before writing through &pageFilePath[0], and is then trimmed to the returned length.
            pageFilePath.resize(MAX_PATH + 1);
            const DWORD tempPathLength = GetTempPath((DWORD) pageFilePath.size(), &pageFilePath[0]);
            if (tempPathLength == 0 || tempPathLength > pageFilePath.size())
                RuntimeError("failed to determine the default temporary directory for pageFilePath");
            pageFilePath.resize(tempPathLength);
#endif
#ifdef __unix__
            pageFilePath = L"/tmp/temp.CNTK.XXXXXX";
#endif
        }

#ifdef _WIN32
        if (pageFilePath.size() > MAX_PATH - 14) // max length of input to GetTempFileName is MAX_PATH-14
            RuntimeError("pageFilePath must be less than %d characters", MAX_PATH - 14);
#else
        if (pageFilePath.size() > PATH_MAX - 14) // keep the mkstemp() template well below PATH_MAX
            RuntimeError("pageFilePath must be less than %d characters", PATH_MAX - 14);
#endif

#ifdef __unix__
        // UTF-8 form of the mkstemp() template; mkstemp() overwrites its argument, so each page file below
        // starts from a fresh, dynamically sized copy (no fixed-size buffer that a long UTF-8 path could overflow)
        const std::string tempFileTemplate(Microsoft::MSR::CNTK::ToLegacyString(Microsoft::MSR::CNTK::ToUTF8(pageFilePath)).c_str());
#endif
        // one unique page file path per feature stream
        foreach_index (i, infilesmulti)
        {
#ifdef _WIN32
            wchar_t tempFile[MAX_PATH];
            if (GetTempFileName(pageFilePath.c_str(), L"CNTK", 0, tempFile) == 0)
                RuntimeError("failed to create a temporary page file in pageFilePath");
            pagePaths.push_back(tempFile);
#endif
#ifdef __unix__
            // mkstemp() creates a uniquely named file; only the name is kept (the file is closed and removed)
            // and the path is passed on to the frame source as a page path.
            std::vector<char> tempFile(tempFileTemplate.begin(), tempFileTemplate.end());
            tempFile.push_back('\0');
            const int fid = mkstemp(tempFile.data());
            if (fid == -1)
                RuntimeError("failed to create a temporary page file from template '%s'", tempFileTemplate.c_str());
            close(fid);
            unlink(tempFile.data());
            pagePaths.push_back(GetWC(tempFile.data()));
#endif
        }

        const bool mayhavenoframe = false;
        int addEnergy = 0;

        m_frameSource.reset(new msra::dbn::minibatchframesourcemulti(infilesmulti, labelsmulti, m_featDims, m_labelDims,
                                                                     numContextLeft, numContextRight, randomize,
                                                                     pagePaths, mayhavenoframe, addEnergy));
        m_frameSource->setverbosity(m_verbosity);
    }
    else
    {
        RuntimeError("readMethod must be 'rollingWindow' or 'blockRandomize'");
    }
}