def performSearch()

in tools/SeeDot/seedot/main.py [0:0]


    def performSearch(self):
        """Search for the best fixed-point scale and, optionally, per-variable bitwidths.

        The search runs in stages:

        * Stage I: probe scales of input 'X' in the range
          (config.maxScaleRange[1], config.maxScaleRange[0]] to find the highest
          and the lowest scale for which fixed-point compilation succeeds.
        * Stage II: compile one inference code per scale in that window, run them
          all on the training set via self.runAll and pick the best scale with
          self.getBestScale() (stored in self.sf).
        * Stage III (config.vbwEnabled only): demote each not-yet-demoted variable
          on its own to config.wordLength // 2 bits, trying several scale offsets
          per variable, and run the resulting codes.
        * Stage IV (config.vbwEnabled only): demote the Stage III candidates
          cumulatively (candidates containing a variable whose size is at least
          Util.Config.largeVariableLimit first) and accept the last cumulative set
          before the first one whose accuracy exceeds the permitted loss.

        When both config.vbwEnabled and config.fixedPointVbwIteration are set, the
        whole procedure repeats until an iteration no longer changes the accepted
        accuracy.

        Side effects: sets self.sf and self.varDemoteDetails, and (VBW mode)
        extends self.demotedVarsList and self.demotedVarsOffsets.

        Returns:
            True when the search completes; False when compilation or execution
            fails at some stage.

        Raises:
            AssertionError: VBW mode is enabled without DDS, or with a native
                bitwidth other than 16.
        """
        start, end = config.maxScaleRange
        lastStageAcc = -1

        def compilesForInputScale(scale):
            # Stage I probe: an exception or a False return both mean "this scale
            # does not work". KeyboardInterrupt is deliberately not caught so that a
            # long search can still be aborted by the user.
            # TODO: Refactor and remove this try/except block in the future.
            try:
                return self.partialCompile(config.Encoding.fixed, config.Target.x86, scale, True, None, 0, dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
            except (Exception, SystemExit):
                return False

        fixedPointCounter = 0
        while True:
            # STAGE I exploration.
            print("Stage I Exploration: Determining scale for input \'X\'...")
            fixedPointCounter += 1
            if config.fixedPointVbwIteration:
                Util.getLogger().debug("Will compile until conversion to fixed point. Iteration %d"%fixedPointCounter)
            # Bar longer than actually required.
            stage_1_bar = tqdm(total=(2 * abs(start - end) + 2), mininterval=0, miniters=1, leave=True)

            # Walk down from the maximum scale until a compile succeeds; scale `end`
            # itself is never tried.
            highestValidScale = start
            while True:
                if highestValidScale == end:
                    stage_1_bar.close()
                    Util.getLogger().error("Compilation not possible for any scale factor of variable \'X\'. Aborting code!")
                    return False
                if compilesForInputScale(highestValidScale):
                    stage_1_bar.update(highestValidScale - end + 1)
                    break
                highestValidScale -= 1
                stage_1_bar.update(1)

            # Walk up from the minimum scale until a compile succeeds. highestValidScale
            # compiled above, so the walk should stop there at the latest; the bound
            # prevents an endless loop if compilation is not reproducible.
            lowestValidScale = end + 1
            while not compilesForInputScale(lowestValidScale):
                if lowestValidScale >= highestValidScale:
                    stage_1_bar.close()
                    Util.getLogger().error("Compilation of variable \'X\' failed at scale %d although a higher scale compiled. Aborting code!" % lowestValidScale)
                    return False
                lowestValidScale += 1
                stage_1_bar.update(1)
            stage_1_bar.update(start - lowestValidScale + 2)
            stage_1_bar.close()

            # Return value intentionally ignored.
            # NOTE(review): the purpose of this extra compile at lowestValidScale (6th
            # argument -1) is not evident from this method — confirm.
            self.partialCompile(config.Encoding.fixed, config.Target.x86, lowestValidScale, True, None, -1, dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))

            print("Stage II Exploration: Determining scale for all non-\'X\' variables...")
            # Every scale in [lowestValidScale, highestValidScale] gets its own inference
            # code; all codes are then run together and the scale with the best accuracy
            # is selected.
            codeId = 0
            codeIdToScaleFactorMap = {}
            for scale in tqdm(range(highestValidScale, lowestValidScale - 1, -1)):
                if config.ddsEnabled:
                    Util.getLogger().debug("Testing with DDS and scale of X as " + str(scale) + "\n")
                else:
                    Util.getLogger().debug("Testing with max scale factor of " + str(scale) + "\n")

                codeId += 1
                # The last code passes its own codeId (= number of codes generated) instead
                # of -1. Keying this on the scale rather than on "codeId == number of scales"
                # keeps the marker correct when an earlier code failed and its id was reused.
                isLastCode = scale == lowestValidScale
                try:
                    compiled = self.partialCompile(
                        config.Encoding.fixed, config.Target.x86, scale, False, codeId, codeId if isLastCode else -1, dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
                except (Exception, SystemExit):  # If some code in the middle fails to compile.
                    Util.getLogger().debug("Compilation failed for scale " + str(scale) + ", skipping it\n")
                    codeId -= 1
                    continue
                if compiled == False:
                    return False
                codeIdToScaleFactorMap[codeId] = scale

            print("Stage II Code Run Started...")
            res, shouldExit = self.runAll(config.Encoding.fixed, config.DatasetType.training, codeIdToScaleFactorMap)
            print("Stage II Code Run Completed!\n")
            if shouldExit == True or res == False:
                return False

            Util.getLogger().info("\nSearch completed\n")
            Util.getLogger().info("----------------------------------------------\n\n")
            Util.getLogger().info("Best performing scaling factors with accuracy, disagreement, reduced disagreement:")

            self.sf = self.getBestScale()
            if self.accuracy[self.sf][0] != lastStageAcc:
                lastStageAcc = self.accuracy[self.sf][0]
            elif config.fixedPointVbwIteration:
                Util.getLogger().info("No difference in iteration %d Stage 2 and iteration %d Stage 1. Stopping search\n"%(fixedPointCounter-1, fixedPointCounter))
                break

            if config.vbwEnabled:
                # Stage III exploration.
                print("Stage III Exploration: Demoting variables one at a time...")

                # Explicit raises so the checks survive `python -O`.
                if not config.ddsEnabled:
                    raise AssertionError("Currently VBW on maxscale not supported")
                if config.wordLength != 16:
                    raise AssertionError("VBW mode only supported if native bitwidth is 16")
                Util.getLogger().debug("Scales computed in native bitwidth. Starting exploration over other bitwidths.")

                # We attempt to demote all possible variables in the code. We try out multiple different scales
                # (controlled by config.offsetsPerDemotedVariable) for each demoted variable. When a variable is
                # demoted, it is assigned a scale given by :
                # demoted Scale = self.allScales[var] + 8 - offset
                attemptToDemote = [var for var in self.variableToBitwidthMap if (var[-3:] != "val" and var not in self.demotedVarsList)]

                # Roughly batchSize codes are clubbed into one generated C++ file so that a single
                # generated file does not become too large. max(1, ...) guards the empty case
                # (everything already demoted in an earlier iteration), which previously divided
                # by zero and crashed in int(inf) with OverflowError.
                numChunks = max(1, int(np.ceil(len(attemptToDemote) / 50)))
                batchSize = int(np.ceil(50 / numChunks))
                redBatchSize = np.max((batchSize, 16)) / config.offsetsPerDemotedVariable

                totalSize = len(attemptToDemote)
                numBatches = int(np.ceil(totalSize / redBatchSize))

                self.varDemoteDetails = []
                for batchIndex in tqdm(range(numBatches)):
                    Util.getLogger().info("=====\nBatch %i out of %d\n=====\n" %(batchIndex + 1, numBatches))

                    firstVarIndex = (totalSize * batchIndex) // numBatches
                    lastVarIndex = (totalSize * (batchIndex + 1)) // numBatches
                    demoteBatch = attemptToDemote[firstVarIndex:lastVarIndex]
                    # 9 offsets are tried for 'X', config.offsetsPerDemotedVariable for every other variable.
                    numCodes = config.offsetsPerDemotedVariable * len(demoteBatch) + ((9 - config.offsetsPerDemotedVariable) if 'X' in demoteBatch else 0)

                    self.partialCompile(config.Encoding.fixed, config.Target.x86, self.sf, True, None, -1 if len(demoteBatch) > 0 else 0, dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
                    codeId = 0
                    contentToCodeIdMap = {}

                    for demoteVar in demoteBatch:
                        # Bitwidths for this candidate: the variable (and its "val" companion, if
                        # any) at half width, plus every variable demoted in earlier iterations.
                        newbitwidths = dict(self.variableToBitwidthMap)
                        newbitwidths[demoteVar] = config.wordLength // 2
                        if demoteVar + "val" in newbitwidths:
                            newbitwidths[demoteVar + "val"] = config.wordLength // 2
                        for alreadyDemotedVar in self.demotedVarsList:
                            newbitwidths[alreadyDemotedVar] = config.wordLength // 2
                        demotedVarsList = [var for var in newbitwidths if newbitwidths[var] != config.wordLength]
                        demotedVarsOffsets = {key: self.demotedVarsOffsets[key] for key in self.demotedVarsList}

                        demotedKey = tuple(demotedVarsList)
                        contentToCodeIdMap[demotedKey] = {}
                        # Try several offsets per variable to find the best scale assignment.
                        offsets = range(0, -9, -1) if demoteVar == 'X' else range(0, -config.offsetsPerDemotedVariable, -1)
                        for demOffset in offsets:
                            codeId += 1
                            for k in demotedVarsList:
                                if k not in self.demotedVarsList:
                                    demotedVarsOffsets[k] = demOffset
                            contentToCodeIdMap[demotedKey][demOffset] = codeId
                            compiled = self.partialCompile(config.Encoding.fixed, config.Target.x86, self.sf, False, codeId, -1 if codeId != numCodes else codeId, dict(newbitwidths), list(demotedVarsList), dict(demotedVarsOffsets))
                            if compiled == False:
                                Util.getLogger().error("Variable bitwidth exploration resulted in a compilation error\n")
                                return False

                    res, shouldExit = self.runAll(config.Encoding.fixed, config.DatasetType.training, None, contentToCodeIdMap)
                    # Check every batch, as Stage II does, instead of silently carrying on.
                    if shouldExit == True or res == False:
                        return False

                print("Stage IV Exploration: Cumulatively demoting variables...")
                # Stage IV exploration.
                # Again, only a limited number of inference codes go into each generated C++ file so as
                # not to bloat the memory usage of the compiler.
                redBatchSize *= config.offsetsPerDemotedVariable
                totalSize = len(self.varDemoteDetails)
                numBatches = int(np.ceil(totalSize / redBatchSize))

                # Candidates whose first sized variable is large (>= Util.Config.largeVariableLimit)
                # are demoted first; all others (including those with no known size) follow.
                largeVarCandidates = []
                otherCandidates = []
                for ((demoteVars, offset), _) in self.varDemoteDetails:
                    sizedVar = next((var for var in demoteVars if var in self.varSizes), None)
                    if sizedVar is not None and self.varSizes[sizedVar] >= Util.Config.largeVariableLimit:
                        largeVarCandidates.append((demoteVars, offset))
                    else:
                        otherCandidates.append((demoteVars, offset))
                sortedVars = largeVarCandidates + otherCandidates

                self.varDemoteDetails = []
                demotedVarsOffsets = dict(self.demotedVarsOffsets)
                demotedVarsList = list(self.demotedVarsList)
                demotedVarsListToOffsets = {}

                # Knowing the accuracy when each single variable is demoted to 8-bits one at a time, we proceed to cumulatively
                # demoting all of them one after the other ensuring accuracy of target code does not fall below a threshold. The
                # following for loop controls generation of inference codes.
                for batchIndex in tqdm(range(numBatches)):
                    Util.getLogger().info("=====\nBatch %i out of %d\n=====\n" %(batchIndex + 1, numBatches))

                    firstVarIndex = (totalSize * batchIndex) // numBatches
                    lastVarIndex = (totalSize * (batchIndex + 1)) // numBatches
                    demoteBatch = sortedVars[firstVarIndex:lastVarIndex]

                    self.partialCompile(config.Encoding.fixed, config.Target.x86, self.sf, True, None, -1 if len(attemptToDemote) > 0 else 0, dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
                    contentToCodeIdMap = {}
                    codeId = 0
                    numCodes = len(demoteBatch)
                    for (demoteVars, offset) in demoteBatch:
                        # NOTE(review): newbitwidths is reset for every code, so only the current
                        # candidate's variables are set to half width here, while demotedVarsList and
                        # demotedVarsOffsets accumulate across codes — confirm this is intended.
                        newbitwidths = dict(self.variableToBitwidthMap)
                        for var in demoteVars:
                            if var not in self.demotedVarsList:
                                newbitwidths[var] = config.wordLength // 2
                                demotedVarsOffsets[var] = offset
                            if var not in demotedVarsList:
                                demotedVarsList.append(var)
                        codeId += 1
                        demotedKey = tuple(demotedVarsList)
                        contentToCodeIdMap[demotedKey] = {offset: codeId}
                        demotedVarsListToOffsets[demotedKey] = dict(demotedVarsOffsets)
                        compiled = self.partialCompile(config.Encoding.fixed, config.Target.x86, self.sf, False, codeId, -1 if codeId != numCodes else codeId, dict(newbitwidths), list(demotedVarsList), dict(demotedVarsOffsets))
                        if compiled == False:
                            Util.getLogger().error("Variable bitwidth exploration resulted in another compilation error\n")
                            return False

                    res, shouldExit = self.runAll(config.Encoding.fixed, config.DatasetType.training, None, contentToCodeIdMap, True)
                    if shouldExit == True or res == False:
                        return False

                # The following for loop controls how many variables are actually demoted in the final output code, which has
                # as many variables as possible in 8-bits, while ensuring accuracy drop compared to floating point is reasonable:
                okToDemote = ()
                acceptedAcc = lastStageAcc
                for ((demotedVars, _), metrics) in self.varDemoteDetails:
                    acc = metrics[0]
                    if self.problemType == config.ProblemType.classification and (self.flAccuracy - acc) > config.permittedClassificationAccuracyLoss:
                        break
                    if self.problemType == config.ProblemType.regression and acc > config.permittedRegressionNumericalLossMargin:
                        break
                    okToDemote = demotedVars
                    acceptedAcc = acc

                # The cumulative Stage IV lists start from self.demotedVarsList, so okToDemote normally
                # already contains those variables: merge without duplicates, keeping first-seen order.
                self.demotedVarsList = list(dict.fromkeys(list(okToDemote) + list(self.demotedVarsList)))
                self.demotedVarsOffsets.update(demotedVarsListToOffsets.get(okToDemote, {}))

                if acceptedAcc != lastStageAcc:
                    lastStageAcc = acceptedAcc
                else:
                    Util.getLogger().warning("No difference in iteration %d's stages 1 & 2. Stopping search."%fixedPointCounter)
                    break

            if not config.vbwEnabled or not config.fixedPointVbwIteration:
                break

        return True