size_t SGD<ElemType>::TrainOneEpoch(...)

Excerpt from Source/SGDLib/SGD.cpp, lines 1005-1750


size_t SGD<ElemType>::TrainOneEpoch(ComputationNetworkPtr net,
                                    ComputationNetworkPtr refNet,
                                    const ComputationNodeBasePtr& refNode,
                                    const int epochNumber,
                                    const size_t epochSize,
                                    IDataReader* trainSetDataReader,
                                    const double learnRatePerSample,
                                    size_t tunedMBSize,
                                    const std::vector<ComputationNodeBasePtr>& featureNodes,
                                    const std::vector<ComputationNodeBasePtr>& labelNodes,
                                    const std::vector<ComputationNodeBasePtr>& criterionNodes,
                                    const std::vector<ComputationNodeBasePtr>& evaluationNodes,
                                    StreamMinibatchInputs* inputMatrices, // TODO: why is this a pointer?
                                    const std::list<ComputationNodeBasePtr>& learnableNodes,
                                    std::list<MatrixBasePtr>& smoothedGradients, vector<double>& smoothedCounts,
                                    /*out*/ EpochCriterion& epochCriterion,
                                    /*out*/ std::vector<EpochCriterion>& epochEvalErrors,
                                    const std::string& prefixMsg,
                                    const size_t maxNumberOfSamples,
                                    const size_t totalMBsSeenBefore,
                                    ::CNTK::Internal::TensorBoardFileWriterPtr tensorBoardWriter,
                                    const int startEpoch)
{
    PROFILE_SCOPE(profilerEvtMainEpoch);

    ScopedNetworkOperationMode modeGuard(net, NetworkOperationMode::training);

    // bring our 'out' values into consistent state
    epochCriterion = EpochCriterion(0);
    epochEvalErrors.assign(epochEvalErrors.size(), EpochCriterion(0));

    double totalTimeInMBs = 0; // use double since timer has sub-microsecond time resolution

    // initialize statistics
    size_t totalEpochSamples = 0;

    int numMBsRun = 0;
    int numMBsRunSinceLastLogged = 0;

    bool useGradientAggregation = UsingGradientAggregation(epochNumber);
    bool useModelAggregation = UsingModelAggregation(epochNumber);
    bool useAsyncGradientAggregation = UsingAsyncGradientAggregation(epochNumber);
    bool useParallelTrain = UsingParallelTrain(epochNumber);

    // Find all evaluation nodes that accumulate error on their own.
    auto evaluationNodesWhichAccumulateResult = net->ExtractNodesWhichAccumulateResult(
        set<ComputationNodeBasePtr>(evaluationNodes.begin(), evaluationNodes.end()));
    auto ContainsAccumulatedResult = [&evaluationNodesWhichAccumulateResult](ComputationNodeBasePtr node) {
        return evaluationNodesWhichAccumulateResult.find(node) != evaluationNodesWhichAccumulateResult.end();
    };

    // MA-related variables
    size_t nSamplesSinceLastModelSync = 0;
    size_t blockSizePerWorker = 0;
    if (useParallelTrain && m_pMASGDHelper)
    {
        m_pMASGDHelper->OnEpochStart(learnableNodes);
        blockSizePerWorker = m_modelAggregationBlockSize / m_mpi->NumNodesInUse();
    }

    std::vector<Matrix<ElemType>*> learnParamsGradients;
    Profiler profiler(m_numMBsToCUDAProfile);

    // resetting this, so profiling is performed for one epoch only
    m_numMBsToCUDAProfile = 0;

    bool useDistributedMBReading = useParallelTrain &&
                                   m_enableDistributedMBReading &&
                                   trainSetDataReader->SupportsDistributedMBRead();
    if (useDistributedMBReading)
    {
        trainSetDataReader->StartDistributedMinibatchLoop(tunedMBSize, epochNumber, m_mpi->CurrentNodeRank(),
            m_mpi->NumNodesInUse(), inputMatrices->GetStreamDescriptions(), epochSize);
    }
    else
    {
        trainSetDataReader->StartMinibatchLoop(tunedMBSize, epochNumber, inputMatrices->GetStreamDescriptions(), epochSize);
    }

    net->StartEvaluateMinibatchLoop(evaluationNodes);
    net->StartEvaluateMinibatchLoop(criterionNodes);
    if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL && refNode)
    {
        refNet->StartEvaluateMinibatchLoop(refNode);
    }

    // prepare for sub-minibatching
    // Sub-minibatching is used if a single minibatch is too large to fit into GPU RAM.
    DataReaderHelpers::SubminibatchDispatcher<ElemType> smbDispatcher;
    size_t numSubminibatchesNeeded = DataReaderHelpers::GetNumSubminibatchesNeeded<ElemType>(trainSetDataReader, m_maxSamplesInRAM, m_numSubminiBatches, tunedMBSize);

    // this is non-trivial, we need a manager object to handle this
    if (numSubminibatchesNeeded > 1)
        smbDispatcher.Init(net, learnableNodes, criterionNodes, evaluationNodes);

    // The following is a special feature only supported by the Kaldi2Reader for more efficient sequence training.
    // This attempts to compute the error signal for the whole utterance, which will
    // be fed to the neural network as features. Currently it is a workaround
    // for the two-forward-pass sequence and ctc training, which allows
    // processing more utterances at the same time.
    // TODO: move the two-forward-pass support out of the reader, make a first-class citizen.
    AttemptUtteranceDerivativeFeatures(net, trainSetDataReader, featureNodes, inputMatrices);

    if (m_traceLevel > 0)
    {
        fprintf(stderr, "\n");
        LOGPRINTF(stderr, "Starting minibatch loop");
        if (useGradientAggregation)
        {
            fprintf(stderr, ", DataParallelSGD training (myRank = %d, numNodes = %d, numGradientBits = %d)",
                    (int) m_mpi->CurrentNodeRank(), (int) m_mpi->NumNodesInUse(), (int) m_numGradientBits[epochNumber]);

            if (m_bufferedAsyncGradientAggregation)
                fprintf(stderr, ", BufferedAsyncGradientAggregation is ENABLED");
        }

        if (useAsyncGradientAggregation)
        {
            fprintf(stderr, ", DataParallelASGD training (myRank = %d, numNodes = %d, SamplesSyncToServer = %d)",
                (int)m_mpi->CurrentNodeRank(), (int)m_mpi->NumNodesInUse(), (int) m_nSyncSamplesPerWorker[epochNumber]);

            fprintf(stderr, ", Distributed Evaluation is DISABLED");

            if (m_isAsyncBufferEnabled)
                fprintf(stderr, ", BufferedAsyncGradientAggregation is ENABLED");
        }

        if (useDistributedMBReading)
            fprintf(stderr, ", distributed reading is ENABLED");

        if (numSubminibatchesNeeded > 1)
        {
            if (m_maxSamplesInRAM < SIZE_MAX)
                fprintf(stderr, ", with maximum %d samples in RAM", (int)m_maxSamplesInRAM);
            else
                fprintf(stderr, ", with %d subminibatch", (int)numSubminibatchesNeeded);
        }
        fprintf(stderr, ".\n");
    }

    Timer timer;
    timer.Start();

    // NOTE: the following two local matrices are not used in distGradAgg path
    // assume only one training criterion node for each epoch.
    // The criterion values are accumulated here over the minibatches (without having to pull them off the GPU).
    // For half, the cr and error nodes should be float nodes
    shared_ptr<CriterionAccumulatorBase> localEpochCriterionPtr = CriterionAccumulatorFactory::CreateCriterionAccumulator<ElemType>(
        criterionNodes, net->GetDeviceId());
    shared_ptr<CriterionAccumulatorBase> localEpochEvalErrorsPtr = CriterionAccumulatorFactory::CreateCriterionAccumulator<ElemType>(
        evaluationNodes, net->GetDeviceId(),
        {evaluationNodesWhichAccumulateResult.begin(), evaluationNodesWhichAccumulateResult.end()});
    CriterionAccumulatorBase& localEpochCriterion = *localEpochCriterionPtr;
    CriterionAccumulatorBase& localEpochEvalErrors = *localEpochEvalErrorsPtr;

    // --- MAIN MINIBATCH LOOP

    // for differential logging, we keep the previous criterion values around
    EpochCriterion         epochCriterionLastLogged  = epochCriterion;
    vector<EpochCriterion> epochEvalErrorsLastLogged = epochEvalErrors;

    EpochCriterion         tensorBoardEpochCriterionLastLogged = epochCriterion;
    vector<EpochCriterion> tensorBoardEpochEvalErrorsLastLogged = epochEvalErrors;

    // NOTE: For ResNet, the regularization in BatchNormalization should be disabled.
    if (m_disableRegInBatchNormalization) {
        let bnNodes = net->GetNodesWithType(L"BatchNormalization");
        for (auto &node : bnNodes)
        {
            let bnNode = dynamic_pointer_cast<BatchNormalizationNode<ElemType>>(node);
            bnNode->DisableRegInBatchNormalization();
        }
    }

    // In case adaptive minibatch/learning rates are used, the training can be limited by the maxNumberOfSamples.
    bool maxNumSamplesExceeded = false;
    size_t epochStartSample = 0;
    bool shouldCheckEarlyExit = (maxNumberOfSamples != SIZE_MAX);
    if (shouldCheckEarlyExit)
    {
        // SparsePC, LibSCV and DSS readers do not implement GetCurrentSamplePosition()
        // for those adaptive minibatch size is not supported, thus specifying adaptive 
        // minibatch for them will cause an error message.
        epochStartSample = trainSetDataReader->GetCurrentSamplePosition();
    }

    auto forwardPropRoots = evaluationNodes;
    forwardPropRoots.push_back(criterionNodes[0]);

    bool noMoreSamplesToProcess = false;
    bool isFirstMinibatch = true;
    for (;;)
    {
        auto profMinibatch = ProfilerTimeBegin();

        // get minibatch
        // TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers?
        size_t actualMBSize = 0;

        auto profGetMinibatch = ProfilerTimeBegin();
        bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork<ElemType>(*trainSetDataReader, net, criterionNodes[0],
                                                                                useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi);

        if (maxNumSamplesExceeded) // Dropping data.
            wasDataRead = false;

        if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) // in case of distributed reading, we do a few more loops until all ranks have completed
            break;                                                                // end of epoch

        // Note: If !wasDataRead then the data that GetMinibatchIntoNetwork() was supposed to fill in are undefined.
        // Must not touch them.

        if (!wasDataRead)
        {
            actualMBSize = 0; // (undefined if !wasDataRead)
            ProfilerEnable(false); // Profiler will be enabled at the beginning of the next epoch.
        }

        ProfilerTimeEnd(profGetMinibatch, profilerEvtMainGetMinibatch);
        auto profForwardBackward = ProfilerTimeBegin();

        nSamplesSinceLastModelSync += actualMBSize;

        // Dropout nodes have an implicit input in the form of the random mask that is applied to its explicit input
        // This mask is regenerated every minibatch and hence dropout nodes with a non-zero dropout rate must me marked outdated
        // w.r.t. inputs to force evaluation in each minibatch
        MarkDropoutNodesEvalTimeStampAsOutdated(net, criterionNodes[0]);

        // node data was changed
        // TODO: move this to that function as well--just tired to pass everything as arguments
        // TODO: We should do this right after the GetMinibatch() call, since that's where these changed.
        //       Need to check whether that would cause unintended side effects.
        // TODO: original code did not call this for actualMBSize == 0
        ComputationNetwork::BumpEvalTimeStamp(featureNodes);
        ComputationNetwork::BumpEvalTimeStamp(labelNodes);

        if (actualMBSize > 0)
        {
            assert(wasDataRead);
#ifndef EVALDLL
            if (m_doGradientCheck && GradientCheck(net, criterionNodes, learnableNodes, 0) == false)
                LogicError("cannot pass gradient checker");
#endif
            // TODO: currently we only support one node for regularization
            if (m_needAdaptRegularization && m_adaptationRegType == AdaptationRegType::KL && refNode)
            {
                size_t actualMBSize2 = refNet->DetermineActualMBSizeFromFeatures();
                refNet->GetMBLayoutPtrOfNetwork()->CopyFrom(net->GetMBLayoutPtrOfNetwork()); // TODO: This is UNTESTED (before this was missing, seemingly inconsistently)

                if (actualMBSize2 != actualMBSize)
                    LogicError("TrainOneEpoch: refNet has different MB size than main net??");

                refNet->ForwardProp(refNode);
                Matrix<ElemType>::ScaleAndAdd((ElemType) m_adaptationRegWeight,
                                              dynamic_pointer_cast<ComputationNode<ElemType>>(refNode)->Value(),
                                              (ElemType)(1.0 - m_adaptationRegWeight),
                                              dynamic_pointer_cast<ComputationNode<ElemType>>(labelNodes[0])->Value());
            }

            // do forward and back propagation

            // We optionally break the minibatch into sub-minibatches.
            // This, when enabled, is used when a full minibatch does not fit into GPU RAM.
            size_t actualNumSubminibatches = numSubminibatchesNeeded <= 1 ? 1 : smbDispatcher.GetMinibatchIntoCache(*trainSetDataReader, *net, *inputMatrices, numSubminibatchesNeeded);
            for (size_t ismb = 0; ismb < actualNumSubminibatches; ismb++)
            {
                if (actualNumSubminibatches > 1)
                {
                    smbDispatcher.GetSubMinibatchToNet(ismb); // get sub-minibatch from full-size one
                    ComputationNetwork::BumpEvalTimeStamp(featureNodes);
                    ComputationNetwork::BumpEvalTimeStamp(labelNodes);
                }

                // ===========================================================
                // forward prop for evaluate eval nodes
                // ===========================================================

                // compute eval node first since when gradient is computed the forward function values
                // may be changed and need to be recomputed when gradient and function value share the same matrix
                net->ForwardProp(forwardPropRoots); // the bulk of this evaluation is reused in ComputeGradient() below

                // ===========================================================
                // backprop
                // ===========================================================

                if (learnRatePerSample > 0.01 * m_minLearnRate) // only compute gradient when learning rate is large enough
                    net->Backprop(criterionNodes[0]);

                // house-keeping for sub-minibatching
                if (actualNumSubminibatches > 1)
                    smbDispatcher.DoneWithCurrentSubMinibatch(ismb); // page state out
            }                                                        // end sub-minibatch loop
            if (actualNumSubminibatches > 1)
                smbDispatcher.DoneWithCurrentMinibatch();
        } // if (actualMBSize > 0)
        // WARNING: If actualMBSize == 0, then criterion nodes have NOT been updated, and contain garbage (last MB's) values.

        // In case of mini epochs (used for adaptive minibatch size and learning rate),
        // no more data should be processed by this worker.
        if (shouldCheckEarlyExit)
        {
            if (epochStartSample + maxNumberOfSamples < trainSetDataReader->GetCurrentSamplePosition())
                maxNumSamplesExceeded = true;
        }

        ProfilerTimeEnd(profForwardBackward, profilerEvtMainFB);
        auto profGradientAgg = ProfilerTimeBegin();

        // for momentum/clipping/regularization/etc., as well as for progress and statistics, we should only count frames that are not gaps
        // #samples according to the default dynamic axis, for use with criterion nodes that do not have an MBLayout
        size_t numSamplesWithLabelOfNetwork = wasDataRead ? net->GetNumSamplesWithLabelOfNetwork(actualMBSize) : 0; // (0 for empty MB)
        // Note: All accumulation into an EpochCriterion uses 'numSamplesWithLabelOfNetwork' as the generic,
        // fallback minibatch size. If that is 0, then nodes are considered containing zero samples,
        // independent of their actual content (which is considered outdated).

        // Sum of actualMBSize across all nodes when using parallel training
        // 'aggregate' here means across-worker aggregate for this one minibatch.
        size_t aggregateNumSamples = actualMBSize; // (0 for empty MB)
        size_t aggregateNumSamplesWithLabel = CriterionAccumulator<ElemType>::GetNumSamples(criterionNodes[0], numSamplesWithLabelOfNetwork); // (0 for empty MB)

        if (!useGradientAggregation)
        {
            // accumulate criterion values (objective, eval)
            assert(wasDataRead || numSamplesWithLabelOfNetwork == 0);
            // criteria are in Value()(0,0), we accumulate into another 1x1 Matrix (to avoid having to pull the values off the GPU)
            localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork);
            for (size_t i = 0; i < evaluationNodes.size(); i++)
                localEpochEvalErrors.Add(i, numSamplesWithLabelOfNetwork);
        }
        else
        {
            // distributed gradient aggregation
            if (learnParamsGradients.size() == 0)
            {
                // lazily form the list of smoothedGradients to exchange
                learnParamsGradients.reserve(learnableNodes.size());
                for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++)
                {
                    ComputationNodePtr node = dynamic_pointer_cast<ComputationNode<ElemType>>(*nodeIter);
                    if (node->IsParameterUpdateRequired())
                    {
                        Matrix<ElemType>* currParamsGradient = &(node->Gradient()); // TODO: we can use shared_ptrs now

                        // Sometimes, in parallel training, the current node may not get any samples to process
                        // In this case, the gradient matrix may not have been sized yet. If so, lets size it.
                        if (currParamsGradient->GetNumCols() == 0)
                        {
                            Matrix<ElemType>* currParamsValues = &(node->Value());
                            currParamsGradient->Resize(currParamsValues->GetNumRows(), currParamsValues->GetNumCols());
                        }

                        learnParamsGradients.push_back(currParamsGradient);
                    }
                }
            }

            // hoist the criterion into CPU space for all-reduce
            localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork);
            for (size_t i = 0; i < evaluationNodes.size(); i++)
                localEpochEvalErrors.Assign(i, numSamplesWithLabelOfNetwork);

            // copy all values to be aggregated into the header
            m_gradHeader->numSamples = aggregateNumSamples;
            m_gradHeader->criterion           = localEpochCriterion.GetCriterion(0).first;
            m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel
            assert(m_gradHeader->numSamplesWithLabel == aggregateNumSamplesWithLabel);
            for (size_t i = 0; i < evaluationNodes.size(); i++)
                m_gradHeader->evalErrors[i] = localEpochEvalErrors.GetCriterion(i);

            // aggregate
            m_gradHeader->numEvalNode = evaluationNodes.size(); // TODO: rename numEvalNode (plural)
            bool samplesProcessed = m_distGradAgg->AggregateGradients(learnParamsGradients, m_gradHeader.get(), isFirstMinibatch);
            noMoreSamplesToProcess = !samplesProcessed;

            // read out the header--now everything is aggregated
            aggregateNumSamples          = m_gradHeader->numSamples;
            aggregateNumSamplesWithLabel = m_gradHeader->numSamplesWithLabel;
            epochCriterion += EpochCriterion(m_gradHeader->criterion, m_gradHeader->numSamplesWithLabel);
            for (size_t i = 0; i < epochEvalErrors.size(); i++)
            {
                if (ContainsAccumulatedResult(evaluationNodes[i]))
                {
                    // We don't accumulate error in epoch criterion as this node has already accumulated error for
                    // all samples that passed through network in forward pass.
                    if (samplesProcessed)
                    {
                        epochEvalErrors[i] = m_gradHeader->evalErrors[i];
                    }
                    // else: no samples processed, no aggregation happened -> we do not want to override current value
                    // with 0.
                }
                else
                    epochEvalErrors[i] += m_gradHeader->evalErrors[i];
            }
        }

        ProfilerTimeEnd(profGradientAgg, profilerEvtMainGradient);
        auto profWeights = ProfilerTimeBegin();

        // update model parameters
        if ((aggregateNumSamples > 0) && (learnRatePerSample > m_minLearnRate * 0.01))
        {
#if 1       // BUGBUG: We must skip gaps in our momentum, clipping, regularization etc. criteria.
            // This will break test cases. So for now, we will only enable this for per-sample criteria.
            size_t numSamplesInMinibatch = aggregateNumSamples;
            if (criterionNodes[0]->HasMBLayout())
#endif
            numSamplesInMinibatch = aggregateNumSamplesWithLabel;
#if 0
            if (numSamplesInMinibatch != aggregateNumSamples)
                fprintf(stderr, "SGD: using true #samples %d instead of MB size %d\n", (int)numSamplesInMinibatch, (int)aggregateNumSamples);
#endif
            auto smoothedGradientIter = smoothedGradients.begin();
            auto smoothedCountIter = smoothedCounts.begin();
            for (auto nodeIter = learnableNodes.begin(); nodeIter != learnableNodes.end(); nodeIter++, smoothedGradientIter++, smoothedCountIter++)
            {
                ComputationNodeBasePtr node = *nodeIter;
                if (node->IsParameterUpdateRequired())
                {
#ifdef _DEBUG
                    bool hasNan = false;
                    if (std::is_same<ElemType, half>())
                    {
                        // Get metrix from compound metrix
                        auto compoundMatrixPtr = dynamic_pointer_cast<Matrix<float>> (*smoothedGradientIter);
                        if (compoundMatrixPtr)
                        {
                            size_t numCols = dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value().GetNumCols();

                            auto smoothedGradient = compoundMatrixPtr->ColumnSlice(0, numCols);
                            hasNan = smoothedGradient.HasNan("TrainOneEpoch/UpdateWeights(): ");
                        }
                    }
                    else
                    {
                        auto smoothedGradient = dynamic_pointer_cast<Matrix<ElemType>> (*smoothedGradientIter);
                        hasNan = smoothedGradient && smoothedGradient->HasNan("TrainOneEpoch/UpdateWeights(): ");
                    }
                    if (hasNan)
                        LogicError("%ls %ls operation has NaNs in smoothedGradient.", node->NodeName().c_str(), node->OperationName().c_str());
#endif
                    double nodeDependentLearningRatePerSample = learnRatePerSample * node->GetLearningRateMultiplier();
                    double nodeDependentRegMultiplier = dynamic_pointer_cast<LearnableParameter<ElemType>>(node)->GetRegMultiplier();
                    double momentumPerSample = GetMomentumPerSample(epochNumber /*BUGBUG workaround:*/, net->GetMBLayoutPtrOfNetwork()->GetNumParallelSequences());
                    // TODO: Check why l2Factor is not applied to L1. Bug?
                    // BUGBUG (Issue #95): Access to net MBLayout can no longer be done if we have multiple input layouts
                    UpdateWeights(dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value(),
                                  dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Gradient(),
                                  *smoothedGradientIter, *smoothedCountIter,
                                  nodeDependentLearningRatePerSample, momentumPerSample,
                                  numSamplesInMinibatch,
                                  m_L2RegWeight * nodeDependentRegMultiplier, m_L1RegWeight * nodeDependentRegMultiplier,
                                  m_needAveMultiplier, m_useNesterovMomentum);
                    node->BumpEvalTimeStamp();
#ifdef _DEBUG
                    if (dynamic_pointer_cast<ComputationNode<ElemType>>(node)->Value().HasNan("TrainOneEpoch/UpdateWeights(): "))
                        LogicError("%ls %ls operation has NaNs in functionValues after parameter update.", node->NodeName().c_str(), node->OperationName().c_str());
#endif
                }
            }
        }


        // aggregation by model averaging or block momentum 
        if (useModelAggregation)
        {
            if (nSamplesSinceLastModelSync >= blockSizePerWorker)
            {
                bool synced = m_pMASGDHelper->OnArrivingAtSyncPoint(learnableNodes, smoothedGradients, nSamplesSinceLastModelSync);
                if (synced)
                {
                    nSamplesSinceLastModelSync = 0;
                }
            }
            // prepare break condition
            if (useDistributedMBReading)
            {
                noMoreSamplesToProcess = !wasDataRead;
            }
        }

        // using parameter server for parameter update
        if (useAsyncGradientAggregation && m_mpi->NumNodesInUse() > 1)
        {
            // Determine if any samples were processed across any of the ranks
            if (useDistributedMBReading)
            {
                noMoreSamplesToProcess = !wasDataRead;
            }

            if (nSamplesSinceLastModelSync >= m_nSyncSamplesPerWorker[epochNumber])
            {
                m_pASGDHelper->PushAndPullModel(learnableNodes, nSamplesSinceLastModelSync);
                nSamplesSinceLastModelSync = 0;
            } 
        }


        ProfilerTimeEnd(profWeights, profilerEvtMainWeights);
        auto profPost = ProfilerTimeBegin();

        timer.Stop();

        numMBsRun++;
        totalTimeInMBs += timer.ElapsedSeconds();

        bool progressPrintNeeded = numMBsRun <= m_firstMBsToShowResult || (m_numMBsToShowResult && (numMBsRun % m_numMBsToShowResult == 0));
        bool tensorBoardWriteNeeded = tensorBoardWriter && m_tensorBoardNumMBsToLogResult && 
            ((totalMBsSeenBefore + numMBsRun) % m_tensorBoardNumMBsToLogResult == 0);

        // Get the epoch Values updated. Take care to fetch values from GPU only when this is really needed.
        if ((progressPrintNeeded || tensorBoardWriteNeeded) && !useGradientAggregation)
        {
            // if no aggregation, we directly get the values from the minibatch accumulators
            timer.Restart();
            epochCriterion = localEpochCriterion.GetCriterion(0);
            for (size_t i = 0; i < epochEvalErrors.size(); i++)
                epochEvalErrors[i] = localEpochEvalErrors.GetCriterion(i);
            timer.Stop();

            // Add the last trailing compute
            totalTimeInMBs += timer.ElapsedSeconds();
        }

        // log
        // This shows the criterion since last logged.
        if (progressPrintNeeded)
        {
            // epochCriterion aggregates over entire epoch, but we only show difference to last time we logged
            EpochCriterion epochCriterionSinceLastLogged = epochCriterion - epochCriterionLastLogged;
            let trainLossSinceLastLogged    =      epochCriterionSinceLastLogged.Average(); // TODO: Check whether old trainSamplesSinceLastLogged matches this ^^ difference
            let trainSamplesSinceLastLogged = (int)epochCriterionSinceLastLogged.second;

            // determine progress in percent
            int mbProgNumPrecision = 2;
            double mbProg = 0.0;

            // Skip epoch size computation if we aren't asked to and epoch is not the starting epoch
            bool skipComputeEpochSize = epochNumber > startEpoch || epochSize != requestDataSize;

            if (skipComputeEpochSize)
            {
                if (m_maxComputedEpochSize != 0)
                {
                    double numMBPerEpoch = (double)m_maxComputedEpochSize / (double)tunedMBSize;
                    mbProg = (double)numMBsRun / numMBPerEpoch;
                    mbProgNumPrecision = (int)ceil(log10(numMBPerEpoch / (double)(numMBsRun - numMBsRunSinceLastLogged)));
                    mbProgNumPrecision = max(mbProgNumPrecision - 2, 2);
                }
            }
            else // estimate epoch size
                m_maxComputedEpochSize = numMBsRun * trainSamplesSinceLastLogged / (numMBsRun - numMBsRunSinceLastLogged);

            // progress tracing for compute cluster management
            let wasProgressPrinted = ProgressTracing::TraceProgressPercentage(epochNumber, mbProg, false);

            // progress tracing for regular log
            if (m_traceLevel > 0)
            {
                PREPENDTS(stderr);
                fprintf(stderr, "%s Epoch[%2d of %d]-Minibatch[%4d-%4d",
                        prefixMsg.c_str(), epochNumber + 1, (int)m_maxEpochs,
                        (int)(numMBsRunSinceLastLogged + 1), numMBsRun);
                if (skipComputeEpochSize)
                    fprintf(stderr, (", %2." + to_string(mbProgNumPrecision) + "f%%").c_str(), mbProg * 100); // --TODO: use a * format?
                fprintf(stderr, "]: ");
                epochCriterionSinceLastLogged.LogCriterion(criterionNodes[0]->NodeName());
                for (size_t i = 0; i < epochEvalErrors.size(); i++)
                {
                    const std::wstring& nodeName = evaluationNodes[i]->NodeName();
                    if (ContainsAccumulatedResult(evaluationNodes[i]))
                    {
                        // For aggregation nodes, we don't report per minibatch error. These nodes calculate
                        // aggregated error for all samples that passed through network, instead of calculating per
                        // sample error. Aggregated error for all samples will be reported for these nodes.
                        epochEvalErrors[i].LogCriterion(nodeName);
                    }
                    else
                    {
                        // Report per minibatch error.
                        (epochEvalErrors[i] - epochEvalErrorsLastLogged[i]).LogCriterion(nodeName);
                    }
                }

                fprintf(stderr, ("time = " + GeneratePaddedFloatOrExpFormat(0, 4, totalTimeInMBs) + "s; samplesPerSecond = %.1f\n").c_str(),
                        totalTimeInMBs, trainSamplesSinceLastLogged / totalTimeInMBs);
            }

            // progress tracing for compute cluster management
            if (wasProgressPrinted)
                ProgressTracing::TraceTrainLoss(trainLossSinceLastLogged);

            if (m_traceLevel > 0)
                fflush(stderr);

            if (epochCriterion.IsNan())
                RuntimeError("The training criterion is not a number (NAN).");

            // reset statistics for differential logging
            epochCriterionLastLogged  = epochCriterion;
            epochEvalErrorsLastLogged = epochEvalErrors;
            numMBsRunSinceLastLogged = numMBsRun;
            for (size_t i = 0; i < epochEvalErrors.size(); i++)
            {
                if (ContainsAccumulatedResult(evaluationNodes[i]))
                {
                    // For nodes that accumulate result we report accumulated error for all samples that passed through
                    // network so far, instead of per minibatch error. So, we reset last logged error here.
                    epochEvalErrorsLastLogged[i] = EpochCriterion(0);
                }
            }

            totalTimeInMBs = 0;
        }

        // Log progress to TensorBoard.
        // Only do this if TensorBoard logging is enabled, the current worker has rank 0, and it is time to write
        // the log (as controlled by tensorBoardNumMBsToLogResult).
        if (tensorBoardWriteNeeded)
        {
            // epochCriterion aggregates over entire epoch, but we only show difference to last time we logged
            EpochCriterion epochCriterionSinceLastLogged = epochCriterion - tensorBoardEpochCriterionLastLogged;
            double trainLossSinceLastLogged = epochCriterionSinceLastLogged.Average();

            // numMBsRun is specific to the current epoch and is reset for each epoch.
            // We cannot use it if we want to view progress of loss/eval since the start of training. 
            // Instead, we use a total number of minibatches run from the start of training as a step.
            const size_t step = totalMBsSeenBefore + (size_t)numMBsRun;
            tensorBoardWriter->WriteValue(L"minibatch/" + criterionNodes[0]->NodeName(), (float)trainLossSinceLastLogged, step);
            for (size_t i = 0; i < epochEvalErrors.size(); i++)
            {
                const std::wstring& nodeName = evaluationNodes[i]->NodeName();
                // For aggregation nodes, we don't report per minibatch error. These nodes calculate
                // aggregated error for all samples that passed through network, instead of calculating per
                // sample error. Aggregated error for all samples will be reported for these nodes.
                const EpochCriterion& evalErrorSinceLastLogged = ContainsAccumulatedResult(evaluationNodes[i])
                    ? epochEvalErrors[i]
                    : epochEvalErrors[i] - tensorBoardEpochEvalErrorsLastLogged[i];
                tensorBoardWriter->WriteValue(L"minibatch/" + nodeName, (float)evalErrorSinceLastLogged.Average(), step);
            }

            tensorBoardWriter->Flush();

            // reset statistics for differential logging
            tensorBoardEpochCriterionLastLogged = epochCriterion;
            tensorBoardEpochEvalErrorsLastLogged = epochEvalErrors;
            for (size_t i = 0; i < epochEvalErrors.size(); i++)
            {
                if (ContainsAccumulatedResult(evaluationNodes[i]))
                {
                    // For nodes that accumulate result we report accumulated error for all samples that passed through
                    // network so far, instead of per minibatch error. So, we reset last logged error here.
                    tensorBoardEpochEvalErrorsLastLogged[i] = EpochCriterion(0);
                }
            }
        }

        timer.Restart();
        totalEpochSamples += aggregateNumSamplesWithLabel;

        // call DataEnd function
        // This signals something from SGD to the reader.
        // DataEnd does reader specific process if sentence ending is reached
        trainSetDataReader->DataEnd();

        // Attempts to compute the error signal for the whole utterance, which will
        // be fed to the neural network as features. Currently it is a workaround
        // for the two-forward-pass sequence and ctc training, which allows
        // processing more utterances at the same time. Only used in Kaldi2Reader.
        // TODO: move the two-forward-pass support out of the reader.
        AttemptUtteranceDerivativeFeatures(net, trainSetDataReader, featureNodes, inputMatrices);

        profiler.NextSample();
        isFirstMinibatch = false;

        ProfilerTimeEnd(profPost, profilerEvtMainPost);
        ProfilerTimeEnd(profMinibatch, profilerEvtMainMinibatch);
    }

    // --- END MAIN MINIBATCH LOOP

    if (useModelAggregation )
    {
        m_pMASGDHelper->OnEpochEnd(learnableNodes, smoothedGradients, nSamplesSinceLastModelSync);
        nSamplesSinceLastModelSync = 0;
    }

    if (useAsyncGradientAggregation && (m_mpi->NumNodesInUse() > 1))
    {
        m_pASGDHelper->PushAndPullModel(learnableNodes, nSamplesSinceLastModelSync);
        nSamplesSinceLastModelSync = 0;
    }

    // hoist the accumulated criterion value from GPU side to our 'out'  variables
    // (unless we useGradientAggregation, in which case they are accumulated in the 'out' variables directly)
    if (!useGradientAggregation)
    {
        epochCriterion = localEpochCriterion.GetCriterion(0);
        for (size_t i = 0; i < epochEvalErrors.size(); i++)
            epochEvalErrors[i] = localEpochEvalErrors.GetCriterion(i);
    }

    // in case of model averaging, do one more final aggregation of criteria
    if (useModelAggregation && (m_mpi->NumNodesInUse() > 1))
    {
        // 1. total epoch samples processed by all workers
        size_t totalEpochSamplesOfAllWorkers = totalEpochSamples;
        m_mpi->AllReduce(&totalEpochSamplesOfAllWorkers, 1);

        // get criteria for this worker
        assert(!useGradientAggregation); // (otherwise the data would not be in localEpochCriterion)
        epochCriterion = localEpochCriterion.GetCriterion(0);
        for (size_t i = 0; i < epochEvalErrors.size(); i++)
            epochEvalErrors[i] = localEpochEvalErrors.GetCriterion(i);

        // all-reduce epochCriterion and epochEvalErrors over nodes
        m_mpi->AllReduce(&epochCriterion.first,  1);
        m_mpi->AllReduce(&epochCriterion.second, 1);
        // to transfer the eval vectors, we must pull them apart into STL objects and exchange them separately
        // TODO: merge with training criteria
        vector<double> numer(epochEvalErrors.size());
        vector<size_t> denom(epochEvalErrors.size());
        for (size_t i = 0; i < epochEvalErrors.size(); i++)
        {
            numer[i] = epochEvalErrors[i].first;
            denom[i] = epochEvalErrors[i].second;
        }
        m_mpi->AllReduce(numer);
        m_mpi->AllReduce(denom);
        for (size_t i = 0; i < epochEvalErrors.size(); i++)
            epochEvalErrors[i] = EpochCriterion(numer[i], denom[i]);

        totalEpochSamples = totalEpochSamplesOfAllWorkers;
    }

    if (useGradientAggregation && !evaluationNodesWhichAccumulateResult.empty())
    {
        // Each worker contains accumulated values for part of the data set, we have to aggregate accumulated values
        // and recalculate evaluation errors based on accumulators.
        AggregateAccumulatorValuesAndUpdateEpochEvaluation<ElemType>(
            net, evaluationNodesWhichAccumulateResult, m_gradHeader, m_mpi, epochEvalErrors, evaluationNodes,
            localEpochEvalErrors, ContainsAccumulatedResult, m_packThresholdSizeInBytes);
    }

    return numMBsRun;
}