bool getbatch()

in Source/Readers/HTKMLFReader/utterancesourcemulti.h [1575:1879]


    // getbatch() -- assemble the minibatch that starts at global frame position 'globalts'
    //
    // Two modes, selected by the member 'framemode':
    //  - utterance mode (!framemode): returns whole utterances in randomized order. 'globalts' must be the
    //    start position of an utterance (a key of 'randomizedutteranceposmap'), otherwise LogicError() is raised.
    //    At least one utterance is returned even if it is longer than 'framesrequested'; further utterances
    //    are appended while the running total stays strictly below 'framesrequested'.
    //  - frame mode: returns individual randomized frames from [globalts, globalts + framesrequested),
    //    clipped to the end of the current sweep.
    //
    // Parameters:
    //  - globalts:         global frame index at which the minibatch starts
    //  - framesrequested:  desired minibatch size in frames
    //  - subsetnum, numsubsets: data-parallel sub-setting; only data from chunks with
    //                      (chunkindex % numsubsets) == subsetnum is paged in and returned
    //  - framesadvanced:   [out] #frames moved ahead in time, counted over all subsets (not only the returned ones)
    //  - feat:             [out] one matrix per feature stream, vdim[i] rows x (#returned frames) columns
    //  - uids:             [out] one class-id vector per label stream; each is cleared if !issupervised()
    //  - transcripts, latticepairs: [out] always cleared; appended to in utterance mode if 'lattices' is non-empty
    //  - sentendmark:      [out] utterance mode only: per feature stream, the end position of each returned
    //                      utterance within the minibatch; left untouched in frame mode
    //  - phoneboundaries2: [out] utterance mode only, and only if m_generatePhoneBoundaries; left untouched otherwise
    //
    // Returns true if any call to requirerandomizedchunk() reported that it paged in data.
    bool getbatch(const size_t globalts, const size_t framesrequested,
                  const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
                  std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
                  std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
                  std::vector<std::shared_ptr<const latticesource::latticepair>> &latticepairs, std::vector<std::vector<size_t>> &sentendmark,
                  std::vector<std::vector<size_t>> &phoneboundaries2) override
    {
        bool readfromdisk = false; // return value: shall be 'true' if we paged in anything

        auto_timer timergetbatch;
        assert(_totalframes > 0);

        // update randomization if a new sweep is entered  --this is a complex operation that updates many of the data members used below
        const size_t sweep = lazyrandomization(globalts);

        // evaluated once up front instead of once per copied frame
        const bool supervised = issupervised();

        size_t mbframes = 0;
        const std::vector<char> noboundaryflags; // dummy
        if (!framemode)                          // regular utterance mode
        {
            // find utterance position for globalts
            // There must be a precise match; it is not possible to specify frames that are not on boundaries.
            auto positer = randomizedutteranceposmap.find(globalts);
            if (positer == randomizedutteranceposmap.end())
                LogicError("getbatch: invalid 'globalts' parameter; must match an existing utterance boundary");
            const size_t spos = positer->second;

            // determine how many utterances will fit into the requested minibatch size
            // Result is the half-open position range [spos, epos).
            mbframes = randomizedutterancerefs[spos].numframes; // at least one utterance, even if too long
            size_t epos;
            for (epos = spos + 1; epos < numutterances && ((mbframes + randomizedutterancerefs[epos].numframes) < framesrequested); epos++) // add more utterances as long as they fit within requested minibatch size
                mbframes += randomizedutterancerefs[epos].numframes;

            // do some paging housekeeping
            // This will also set the feature-kind information if it's the first time.
            // Free all chunks left of the range.
            // Page-in all chunks right of the range.
            // We are a little more blunt for now: Free all outside the range, and page in only what is touched. We could save some loop iterations.
            const size_t windowbegin = positionchunkwindows[spos].windowbegin();
            const size_t windowend = positionchunkwindows[epos - 1].windowend();
            for (size_t k = 0; k < windowbegin; k++)
                releaserandomizedchunk(k);
            for (size_t k = windowend; k < randomizedchunks[0].size(); k++)
                releaserandomizedchunk(k);
            for (size_t pos = spos; pos < epos; pos++)
                if ((randomizedutterancerefs[pos].chunkindex % numsubsets) == subsetnum)
                    readfromdisk |= requirerandomizedchunk(randomizedutterancerefs[pos].chunkindex, windowbegin, windowend); // (window range passed in for checking only)

            // Note that the above loop loops over all chunks incl. those that we already should have.
            // This has an effect, e.g., if 'numsubsets' has changed (we will fill gaps).

            // determine the true #frames we return, for allocation--it is less than mbframes in the case of MPI/data-parallel sub-set mode
            size_t tspos = 0;
            for (size_t pos = spos; pos < epos; pos++)
            {
                const auto &uttref = randomizedutterancerefs[pos];
                if ((uttref.chunkindex % numsubsets) != subsetnum) // chunk not to be returned for this MPI node
                    continue;

                tspos += uttref.numframes;
            }

            // resize the output containers
            feat.resize(vdim.size());
            uids.resize(classids.size());
            if (m_generatePhoneBoundaries)
                phoneboundaries2.resize(classids.size());
            sentendmark.resize(vdim.size());
            assert(feat.size() == vdim.size());
            assert(feat.size() == randomizedchunks.size());
            foreach_index (i, feat)
                feat[i].resize(vdim[i], tspos);

            // Size (supervised) or empty (unsupervised) every label stream.
            // This must index by the label stream 'j'; indexing by the feature-stream index would only ever touch stream 0.
            foreach_index (j, uids)
            {
                if (supervised) // empty means unsupervised training -> return empty uids
                {
                    uids[j].resize(tspos);
                    if (m_generatePhoneBoundaries)
                        phoneboundaries2[j].resize(tspos);
                }
                else
                {
                    uids[j].clear();
                    if (m_generatePhoneBoundaries)
                        phoneboundaries2[j].clear();
                }
            }
            foreach_index (j, sentendmark)
                sentendmark[j].clear(); // will push_back() below

            // cleared exactly once, independent of how many label streams exist
            latticepairs.clear(); // will push_back() below
            transcripts.clear();

            // return these utterances
            if (verbosity > 0)
                fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int) spos, (int) (epos - 1), (int) tspos, (int) mbframes, (int) framesrequested, (int) sweep);
            tspos = 0; // relative start of utterance 'pos' within the returned minibatch
            for (size_t pos = spos; pos < epos; pos++)
            {
                const auto &uttref = randomizedutterancerefs[pos];
                if ((uttref.chunkindex % numsubsets) != subsetnum) // chunk not to be returned for this MPI node
                    continue;

                size_t n = 0; // #frames of this utterance; set by each stream below and used to advance tspos
                foreach_index (i, randomizedchunks)
                {
                    const auto &chunk = randomizedchunks[i][uttref.chunkindex];
                    const auto &chunkdata = chunk.getchunkdata();
                    assert((numsubsets > 1) || (uttref.globalts == globalts + tspos));
                    auto uttframes = chunkdata.getutteranceframes(uttref.utteranceindex());
                    matrixasvectorofvectors uttframevectors(uttframes); // (wrapper that allows m[j].size() and m[j][i] as required by augmentneighbors())
                    n = uttframevectors.size();
                    sentendmark[i].push_back(n + tspos); // end position of this utterance within the minibatch
                    assert(n == uttframes.cols() && uttref.numframes == n && chunkdata.numframes(uttref.utteranceindex()) == n);

                    // copy the frames, augmented with their neighbor frames
                    for (size_t t = 0; t < n; t++) // t = time index into source utterance
                    {
                        size_t leftextent, rightextent;
                        // no explicit context configured for this stream: derive a symmetric extent from the dimensions
                        if (leftcontext[i] == 0 && rightcontext[i] == 0)
                        {
                            leftextent = rightextent = augmentationextent(uttframevectors[t].size(), vdim[i]);
                        }
                        else
                        {
                            leftextent = leftcontext[i];
                            rightextent = rightcontext[i];
                        }
                        augmentneighbors(uttframevectors, noboundaryflags, t, leftextent, rightextent, feat[i], t + tspos);
                    }

                    // copy the class labels, lattices and transcripts (once per utterance, while handling the first feature stream)
                    if (i == 0)
                    {
                        auto uttclassids = getclassids(uttref);
                        std::vector<shiftedvector<biggrowablevector<HMMIDTYPE>>> uttphoneboudaries;
                        if (m_generatePhoneBoundaries)
                            uttphoneboudaries = getphonebound(uttref);
                        foreach_index (j, uttclassids)
                        {
                            if (supervised)
                            {
                                for (size_t t = 0; t < n; t++) // t = time index into source utterance
                                {
                                    uids[j][t + tspos] = uttclassids[j][t];
                                    if (m_generatePhoneBoundaries)
                                        phoneboundaries2[j][t + tspos] = uttphoneboudaries[j][t];
                                }
                            }

                            // NOTE(review): this runs once per label stream, so an utterance contributes one lattice
                            // entry per label stream (and none if there are no label streams) -- confirm this is intended.
                            if (!this->lattices.empty())
                            {
                                auto latticepair = chunkdata.getutterancelattice(uttref.utteranceindex());
                                latticepairs.push_back(latticepair);
                                // look up reference
                                const auto &key = latticepair->getkey();
                                if (!allwordtranscripts.empty())
                                {
                                    // a missing key must not be dereferenced (end() iterator)
                                    const auto transcriptiter = allwordtranscripts.find(key);
                                    if (transcriptiter == allwordtranscripts.end())
                                        LogicError("getbatch: no word transcript found for the key of a lattice");
                                    transcripts.push_back(transcriptiter->second.words);
                                }
                            }
                        }
                    }
                }
                tspos += n;
            }

            foreach_index (i, feat)
            {
                assert(tspos == feat[i].cols());
            }
        }
        else
        {
            const size_t sweepts = sweep * _totalframes;                           // first global frame index for this sweep
            const size_t sweepte = sweepts + _totalframes;                         // and its end
            const size_t globalte = std::min(globalts + framesrequested, sweepte); // we return as much as requested, but not exceeding sweep end
            mbframes = globalte - globalts;                                        // that's our mb size

            // Perform randomization of the desired frame range
            m_frameRandomizer.randomizeFrameRange(globalts, globalte);

            // determine window range
            // We enumerate all frames--can this be done more efficiently?
            const size_t firstchunk = chunkforframepos(globalts);
            const size_t lastchunk = chunkforframepos(globalte - 1);
            const size_t windowbegin = randomizedchunks[0][firstchunk].windowbegin;
            const size_t windowend = randomizedchunks[0][lastchunk].windowend;
            if (verbosity > 0)
                fprintf(stderr, "getbatch: getting randomized frames [%d..%d] (%d frames out of %d requested) in sweep %d; chunks [%d..%d] -> chunk window [%d..%d)\n",
                        (int) globalts, (int) globalte, (int) mbframes, (int) framesrequested, (int) sweep, (int) firstchunk, (int) lastchunk, (int) windowbegin, (int) windowend);
            // release all data outside, and page in all data inside
            for (size_t k = 0; k < windowbegin; k++)
                releaserandomizedchunk(k);
            for (size_t k = windowbegin; k < windowend; k++)
                if ((k % numsubsets) == subsetnum)                                     // in MPI mode, we skip chunks this way
                    readfromdisk |= requirerandomizedchunk(k, windowbegin, windowend); // (window range passed in for checking only, redundant here)
            for (size_t k = windowend; k < randomizedchunks[0].size(); k++)
                releaserandomizedchunk(k);

            // determine the true #frames we return--it is less than mbframes in the case of MPI/data-parallel sub-set mode
            // Only the count for our own subset is needed, so no per-subset table is built.
            // TODO: No, return all; and leave it to caller to redistribute them [Zhijie Yan]
            size_t subsetframes = 0;
            for (size_t i = 0; i < mbframes; i++) // i is input frame index
            {
                const frameref &frameref = m_frameRandomizer.randomizedframeref(globalts + i);
                if ((frameref.chunkindex % numsubsets) == subsetnum)
                    subsetframes++;
            }
            const size_t allocframes = std::max(subsetframes, (mbframes + numsubsets - 1) / numsubsets); // we leave space for the desired #frames, assuming caller will try to pad them later

            // resize the output containers
            feat.resize(vdim.size());
            uids.resize(classids.size());
            assert(feat.size() == vdim.size());
            assert(feat.size() == randomizedchunks.size());
            foreach_index (i, feat)
            {
                feat[i].resize(vdim[i], allocframes); // allocate for the padded size...
                feat[i].shrink(vdim[i], subsetframes); // ...then shrink to what we actually return
            }
            foreach_index (k, uids)
            {
                if (supervised) // empty means unsupervised training -> return empty uids
                    uids[k].resize(subsetframes);
                else
                    uids[k].clear();
            }

            // frame mode does not append lattices or transcripts; they are returned empty
            latticepairs.clear();
            transcripts.clear();

            // return randomized frames for the time range of those utterances
            size_t currmpinodeframecount = 0; // output column index = #frames returned so far for this subset
            for (size_t j2 = 0; j2 < mbframes; j2++)
            {
                if (currmpinodeframecount >= feat[0].cols()) // MPI/data-parallel mode: all nodes return the same #frames, which is how feat(,) is allocated
                    break;

                // map to time index inside arrays
                const frameref &frameref = m_frameRandomizer.randomizedframeref(globalts + j2);

                // in MPI/data-parallel mode, skip frames that are not in chunks loaded for this MPI node
                if ((frameref.chunkindex % numsubsets) != subsetnum)
                    continue;

                // random utterance
                readfromdisk |= requirerandomizedchunk(frameref.chunkindex, windowbegin, windowend); // (this is just a check; should not actually page in anything)

                foreach_index (i, randomizedchunks)
                {
                    const auto &chunk = randomizedchunks[i][frameref.chunkindex];
                    const auto &chunkdata = chunk.getchunkdata();
                    auto uttframes = chunkdata.getutteranceframes(frameref.utteranceindex());
                    matrixasvectorofvectors uttframevectors(uttframes); // (wrapper that allows m[.].size() and m[.][.] as required by augmentneighbors())
                    const size_t n = uttframevectors.size();
                    assert(n == uttframes.cols() && chunkdata.numframes(frameref.utteranceindex()) == n);
                    (void) n; // only used by the assert above; keeps builds without asserts free of unused-variable warnings

                    // copy frame and class labels
                    const size_t t = frameref.frameindex();

                    size_t leftextent, rightextent;
                    // no explicit context configured for this stream: derive a symmetric extent from the dimensions
                    if (leftcontext[i] == 0 && rightcontext[i] == 0)
                    {
                        leftextent = rightextent = augmentationextent(uttframevectors[t].size(), vdim[i]);
                    }
                    else
                    {
                        leftextent = leftcontext[i];
                        rightextent = rightcontext[i];
                    }
                    augmentneighbors(uttframevectors, noboundaryflags, t, leftextent, rightextent, feat[i], currmpinodeframecount);

                    if (supervised && i == 0)
                    {
                        auto frameclassids = getclassids(frameref);
                        foreach_index (k, uids)
                            uids[k][currmpinodeframecount] = frameclassids[k][t];
                    }
                }

                currmpinodeframecount++;
            }
        }
        timegetbatch = timergetbatch;

        // this is the number of frames we actually moved ahead in time
        framesadvanced = mbframes;

        return readfromdisk;
    }