def train()

in tf/edgeml_tf/trainer/bonsaiTrainer.py


    def train(self, batchSize, totalEpochs, sess,
              Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir):
        '''
        The dense -> IHT -> sparse-retrain routine for Bonsai training.
        '''
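        # The schedule implemented below (inferred directly from this routine):
        #   Phase 1 (first  ~totalBatches/3 mini-batches): dense training of
        #       all parameters.
        #   Phase 2 (middle ~totalBatches/3): iterative hard thresholding --
        #       runHardThrsd() every `trimlevel` batches, interleaved with
        #       sparse retraining on the surviving support.
        #   Phase 3 (final  ~totalBatches/3): sparse retraining only, via
        #       runSparseTraining().
        # When isDenseTraining is set, the model stays dense throughout.
        # sigmaI controls the soft path indicator; it is annealed during
        # training and set very large (1e9) at evaluation time so that a
        # single tree path is used.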
        resultFile = open(dataDir + '/TFBonsaiResults.txt', 'a+')
        numIters = Xtrain.shape[0] / batchSize

        totalBatches = numIters * totalEpochs

        bonsaiObjSigmaI = 1

        counter = 0
        if self.bonsaiObj.numClasses > 2:
            trimlevel = 15
        else:
            trimlevel = 5
        ihtDone = 0
        if (self.bonsaiObj.isRegression is True):
            maxTestAcc = 100000007
        else:
            maxTestAcc = -10000
        if self.isDenseTraining is True:
            ihtDone = 1
            bonsaiObjSigmaI = 1
            itersInPhase = 0

        header = '*' * 20
        for i in range(totalEpochs):
            print("\nEpoch Number: " + str(i), file=self.outFile)

            '''
            trainAcc is the Mean Absolute Error for regression
            and the accuracy for classification.
            '''
            trainAcc = 0.0
            trainLoss = 0.0

            numIters = int(numIters)
            for j in range(numIters):

                if counter == 0:
                    msg = " Dense Training Phase Started "
                    print("\n%s%s%s\n" %
                          (header, msg, header), file=self.outFile)

                # Update the path-indicator sigma: it is reset to 1 at the
                # start of each of the three phases and, every 100 iterations,
                # re-estimated from the mean |T_k . x_hat| over a random batch
                # (x_hat = projected input) so that the sigmoid argument starts
                # near 0.1 and then grows geometrically, doubling every
                # totalBatches/30 iterations and capped at 1000.
                if ((counter == 0) or (counter == int(totalBatches / 3.0)) or
                        (counter == int(2 * totalBatches / 3.0))) and (self.isDenseTraining is False):
                    bonsaiObjSigmaI = 1
                    itersInPhase = 0

                elif (itersInPhase % 100 == 0):
                    indices = np.random.choice(Xtrain.shape[0], 100)
                    batchX = Xtrain[indices, :]
                    batchY = Ytrain[indices, :]
                    batchY = np.reshape(
                        batchY, [-1, self.bonsaiObj.numClasses])

                    _feed_dict = {self.X: batchX}
                    Xcapeval = self.X_.eval(feed_dict=_feed_dict)
                    Teval = self.bonsaiObj.T.eval()

                    sum_tr = 0.0
                    for k in range(0, self.bonsaiObj.internalNodes):
                        sum_tr += (np.sum(np.abs(np.dot(Teval[k], Xcapeval))))

                    if (self.bonsaiObj.internalNodes > 0):
                        sum_tr /= (100 * self.bonsaiObj.internalNodes)
                        sum_tr = 0.1 / sum_tr
                    else:
                        sum_tr = 0.1
                    sum_tr = min(
                        1000, sum_tr * (2**(float(itersInPhase) /
                                            (float(totalBatches) / 30.0))))

                    bonsaiObjSigmaI = sum_tr

                itersInPhase += 1
                batchX = Xtrain[j * batchSize:(j + 1) * batchSize]
                batchY = Ytrain[j * batchSize:(j + 1) * batchSize]
                batchY = np.reshape(
                    batchY, [-1, self.bonsaiObj.numClasses])

                if self.bonsaiObj.numClasses > 2:
                    if self.useMCHLoss is True:
                        _feed_dict = {self.X: batchX, self.Y: batchY,
                                      self.batch_th: batchY.shape[0],
                                      self.sigmaI: bonsaiObjSigmaI}
                    else:
                        _feed_dict = {self.X: batchX, self.Y: batchY,
                                      self.sigmaI: bonsaiObjSigmaI}
                else:
                    _feed_dict = {self.X: batchX, self.Y: batchY,
                                  self.sigmaI: bonsaiObjSigmaI}

                # Mini-batch training
                _, batchLoss, batchAcc = sess.run(
                    [self.trainStep, self.loss, self.accuracy],
                    feed_dict=_feed_dict)

                # Classification.
                if (self.bonsaiObj.isRegression is False):
                    trainAcc += batchAcc
                    trainLoss += batchLoss
                # Regression.
                else:
                    trainAcc += np.mean(batchAcc)
                    trainLoss += np.mean(batchLoss)

                # Training routine involving IHT and sparse retraining
                if (counter >= int(totalBatches / 3.0) and
                    (counter < int(2 * totalBatches / 3.0)) and
                    counter % trimlevel == 0 and
                        self.isDenseTraining is False):
                    self.runHardThrsd(sess)
                    if ihtDone == 0:
                        msg = " IHT Phase Started "
                        print("\n%s%s%s\n" %
                              (header, msg, header), file=self.outFile)
                    ihtDone = 1
                elif ((ihtDone == 1 and counter >= int(totalBatches / 3.0) and
                       (counter < int(2 * totalBatches / 3.0)) and
                       counter % trimlevel != 0 and
                       self.isDenseTraining is False) or
                        (counter >= int(2 * totalBatches / 3.0) and
                            self.isDenseTraining is False)):
                    self.runSparseTraining(sess)
                    if counter == int(2 * totalBatches / 3.0):
                        msg = " Sparse Retraining Phase Started "
                        print("\n%s%s%s\n" %
                              (header, msg, header), file=self.outFile)
                counter += 1
            try:
                if (self.bonsaiObj.isRegression is True):
                    print("\nRegression Train Loss: " + str(trainLoss / numIters) +
                          "\nTraining MAE (Regression): " +
                          str(trainAcc / numIters),
                          file=self.outFile)
                else:
                    print("\nClassification Train Loss: " + str(trainLoss / numIters) +
                          "\nTraining accuracy (Classification): " +
                          str(trainAcc / numIters),
                          file=self.outFile)
            except Exception:
                continue

            oldSigmaI = bonsaiObjSigmaI
            bonsaiObjSigmaI = 1e9

            if self.bonsaiObj.numClasses > 2:
                if self.useMCHLoss is True:
                    _feed_dict = {self.X: Xtest, self.Y: Ytest,
                                  self.batch_th: Ytest.shape[0],
                                  self.sigmaI: bonsaiObjSigmaI}
                else:
                    _feed_dict = {self.X: Xtest, self.Y: Ytest,
                                  self.sigmaI: bonsaiObjSigmaI}
            else:
                _feed_dict = {self.X: Xtest, self.Y: Ytest,
                              self.sigmaI: bonsaiObjSigmaI}

            # Evaluate on the test set in-session, instead of exporting the model first

            testAcc, testLoss, regTestLoss, pred = sess.run(
                [self.accuracy, self.loss, self.regLoss, self.prediction], feed_dict=_feed_dict)

            if ihtDone == 0:
                if (self.bonsaiObj.isRegression is False):
                    maxTestAcc = -10000
                    maxTestAccEpoch = i
                elif (self.bonsaiObj.isRegression is True):
                    maxTestAcc = testAcc
                    maxTestAccEpoch = i

            else:
                if (self.bonsaiObj.isRegression is False):
                    if maxTestAcc <= testAcc:
                        maxTestAccEpoch = i
                        maxTestAcc = testAcc
                        self.saveParams(currDir)
                        self.saveParamsForSeeDot(currDir)
                elif (self.bonsaiObj.isRegression is True):
                    print("Minimum Training MAE : ", np.mean(maxTestAcc))
                    if maxTestAcc >= testAcc:
                        # For regression , we're more interested in the minimum
                        # MAE.
                        maxTestAccEpoch = i
                        maxTestAcc = testAcc
                        self.saveParams(currDir)
                        self.saveParamsForSeeDot(currDir)

            if (self.bonsaiObj.isRegression is True):
                print("Testing MAE %g" % np.mean(testAcc), file=self.outFile)
            else:
                print("Test accuracy %g" % np.mean(testAcc), file=self.outFile)

            if (self.bonsaiObj.isRegression is True):
                testAcc = np.mean(testAcc)

            print("MarginLoss + RegLoss: " + str(testLoss - regTestLoss) +
                  " + " + str(regTestLoss) + " = " + str(testLoss) + "\n",
                  file=self.outFile)
            self.outFile.flush()

            bonsaiObjSigmaI = oldSigmaI

        # sigmaI has to be set to infinity to ensure
        # only a single path is used in inference
        bonsaiObjSigmaI = 1e9
        print("\nNon-Zero : " + str(self.getModelSize()[0]) + " Model Size: " +
              str(float(self.getModelSize()[1]) / 1024.0) + " KB hasSparse: " +
              str(self.getModelSize()[2]) + "\n", file=self.outFile)

        if (self.bonsaiObj.isRegression is True):
            maxTestAcc = np.mean(maxTestAcc)

        if (self.bonsaiObj.isRegression is True):
            print("For Regression, Minimum MAE at compressed" +
                  " model size(including early stopping): " +
                  str(maxTestAcc) + " at Epoch: " +
                  str(maxTestAccEpoch + 1) + "\nFinal Test" +
                  " MAE: " + str(testAcc), file=self.outFile)

            resultFile.write("MinTestMAE: " + str(maxTestAcc) +
                             " at Epoch(totalEpochs): " +
                             str(maxTestAccEpoch + 1) +
                             "(" + str(totalEpochs) + ")" + " ModelSize: " +
                             str(float(self.getModelSize()[1]) / 1024.0) +
                             " KB hasSparse: " + str(self.getModelSize()[2]) +
                             " Param Directory: " +
                             str(os.path.abspath(currDir)) + "\n")

        elif (self.bonsaiObj.isRegression is False):
            print("For Classification, Maximum Test accuracy at compressed" +
                  " model size(including early stopping): " +
                  str(maxTestAcc) + " at Epoch: " +
                  str(maxTestAccEpoch + 1) + "\nFinal Test" +
                  " Accuracy: " + str(testAcc), file=self.outFile)

            resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
                             " at Epoch(totalEpochs): " +
                             str(maxTestAccEpoch + 1) +
                             "(" + str(totalEpochs) + ")" + " ModelSize: " +
                             str(float(self.getModelSize()[1]) / 1024.0) +
                             " KB hasSparse: " + str(self.getModelSize()[2]) +
                             " Param Directory: " +
                             str(os.path.abspath(currDir)) + "\n")
        print("The Model Directory: " + currDir + "\n")

        resultFile.close()
        self.outFile.flush()

        if self.outFile is not sys.stdout:
            self.outFile.close()
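
For reference, a minimal invocation sketch follows. It assumes `trainer` is an already constructed BonsaiTrainer wrapping a Bonsai graph (see the repo's Bonsai example script for the constructor arguments, which are not reproduced here); the data is random stand-in data and all hyperparameter values are placeholders. Only the `train(...)` call itself mirrors the signature documented above.

    # TensorFlow 1.x API, matching the trainer code above.
    import numpy as np
    import tensorflow as tf

    # Random stand-in data: 16-dimensional features, 3 classes, one-hot labels.
    Xtrain = np.random.rand(1000, 16).astype(np.float32)
    Ytrain = np.eye(3)[np.random.randint(0, 3, 1000)]
    Xtest = np.random.rand(200, 16).astype(np.float32)
    Ytest = np.eye(3)[np.random.randint(0, 3, 200)]

    # `trainer` is assumed to be a BonsaiTrainer instance built as in the
    # repo's Bonsai example; its construction is deliberately omitted here.
    sess = tf.InteractiveSession()               # the .eval() calls above need a default session
    sess.run(tf.global_variables_initializer())

    trainer.train(batchSize=32, totalEpochs=100, sess=sess,
                  Xtrain=Xtrain, Xtest=Xtest, Ytrain=Ytrain, Ytest=Ytest,
                  dataDir='.', currDir='.')      # both directories must already exist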