in tools/SeeDot/seedot/main.py [0:0]
def performSearch(self):
    """
    Run the multi-stage fixed-point exploration and record the chosen configuration on self.

    Each pass of the outer loop performs:
      * Stage I   - find the highest and lowest scale of 'X' (within config.maxScaleRange)
                    for which self.partialCompile succeeds.
      * Stage II  - compile one code per scale in that window, execute them through
                    self.runAll and store the scale returned by self.getBestScale() in self.sf.
      * Stage III - (only if config.vbwEnabled) demote each candidate variable to half the
                    word length on its own, trying several scale offsets per variable.
      * Stage IV  - (only if config.vbwEnabled) demote variables cumulatively and keep the
                    longest prefix whose accuracy stays within the permitted loss.

    The outer loop repeats only when both config.vbwEnabled and config.fixedPointVbwIteration
    are set, and stops as soon as a stage no longer changes the accepted accuracy.

    Attributes written: self.sf, self.varDemoteDetails, self.demotedVarsList,
    self.demotedVarsOffsets.

    Returns:
        True when the search finishes, False when no scale of 'X' compiles, when a
        compilation reports failure, or when a Stage II / Stage IV code run fails.
    """
    start, end = config.maxScaleRange
    lastStageAcc = -1
    fixedPointCounter = 0

    while True:
        # STAGE I exploration.
        print("Stage I Exploration: Determining scale for input \'X\'...")
        fixedPointCounter += 1
        if config.fixedPointVbwIteration:
            Util.getLogger().debug("Will compile until conversion to fixed point. Iteration %d"%fixedPointCounter)

        highestValidScale = start
        firstCompileSuccess = False
        # Bar longer than actually required
        stage_1_bar = tqdm(total=(2 * abs(start - end) + 2), mininterval=0, miniters=1, leave=True)
        while firstCompileSuccess == False:
            if highestValidScale == end:
                Util.getLogger().error("Compilation not possible for any scale factor of variable \'X\'. Aborting code!")
                return False

            # Refactor and remove this try/catch block in the future.
            # Only Exception is caught so that KeyboardInterrupt / SystemExit still propagate.
            try:
                firstCompileSuccess = self.partialCompile(
                    config.Encoding.fixed, config.Target.x86, highestValidScale, True, None, 0,
                    dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
            except Exception as err:
                Util.getLogger().debug("Compilation failed for scale %d of 'X': %s" % (highestValidScale, err))
                firstCompileSuccess = False

            if firstCompileSuccess:
                stage_1_bar.update(highestValidScale - end + 1)
                break
            highestValidScale -= 1
            stage_1_bar.update(1)

        # Walk upwards from the bottom of the range to find the lowest scale that compiles.
        # NOTE(review): there is no explicit upper bound here; the loop presumably stops at
        # highestValidScale at the latest, since that scale compiled above - confirm.
        lowestValidScale = end + 1
        firstCompileSuccess = False
        while firstCompileSuccess == False:
            try:
                firstCompileSuccess = self.partialCompile(
                    config.Encoding.fixed, config.Target.x86, lowestValidScale, True, None, 0,
                    dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
            except Exception as err:
                Util.getLogger().debug("Compilation failed for scale %d of 'X': %s" % (lowestValidScale, err))
                firstCompileSuccess = False

            if firstCompileSuccess:
                stage_1_bar.update(start - lowestValidScale + 2)
                break
            lowestValidScale += 1
            stage_1_bar.update(1)
        stage_1_bar.close()

        # Ignored.
        self.partialCompile(
            config.Encoding.fixed, config.Target.x86, lowestValidScale, True, None, -1,
            dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))

        print("Stage II Exploration: Determining scale for all non-\'X\' variables...")
        # The iterator logic is as follows:
        # Search begins when the first valid scaling factor is found (runOnce returns True).
        # Search ends when the execution fails on a particular scaling factor (runOnce returns False).
        # This is the window where valid scaling factors exist and we
        # select the one with the best accuracy.
        numCodes = highestValidScale - lowestValidScale + 1
        codeId = 0
        codeIdToScaleFactorMap = {}
        for scale in tqdm(range(highestValidScale, lowestValidScale - 1, -1)):
            if config.ddsEnabled:
                Util.getLogger().debug("Testing with DDS and scale of X as " + str(scale) + "\n")
            else:
                Util.getLogger().debug("Testing with max scale factor of " + str(scale) + "\n")

            codeId += 1
            try:
                compiled = self.partialCompile(
                    config.Encoding.fixed, config.Target.x86, scale, False, codeId,
                    -1 if codeId != numCodes else codeId,
                    dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
            except Exception as err:  # If some code in the middle fails to compile.
                # NOTE(review): once a scale is skipped, codeId can no longer reach numCodes,
                # so the "codeId == numCodes" argument above is never passed - confirm intended.
                Util.getLogger().debug("Skipping scale %d, compilation failed: %s" % (scale, err))
                codeId -= 1
                continue
            if compiled == False:
                return False
            codeIdToScaleFactorMap[codeId] = scale

        print("Stage II Code Run Started...")
        res, shouldExit = self.runAll(config.Encoding.fixed, config.DatasetType.training, codeIdToScaleFactorMap)
        print("Stage II Code Run Completed!\n")
        if shouldExit == True or res == False:
            return False

        Util.getLogger().info("\nSearch completed\n")
        Util.getLogger().info("----------------------------------------------\n\n")
        Util.getLogger().info("Best performing scaling factors with accuracy, disagreement, reduced disagreement:")

        self.sf = self.getBestScale()
        if self.accuracy[self.sf][0] != lastStageAcc:
            lastStageAcc = self.accuracy[self.sf][0]
        elif config.fixedPointVbwIteration:
            Util.getLogger().info("No difference in iteration %d Stage 2 and iteration %d Stage 1. Stopping search\n"%(fixedPointCounter-1, fixedPointCounter))
            break

        if config.vbwEnabled:
            # Stage III exploration.
            print("Stage III Exploration: Demoting variables one at a time...")

            assert config.ddsEnabled, "Currently VBW on maxscale not supported"
            assert config.wordLength == 16, "VBW mode only supported if native bitwidth is 16"
            Util.getLogger().debug("Scales computed in native bitwidth. Starting exploration over other bitwidths.")

            # We attempt to demote all possible variables in the code. We try out multiple different scales
            # (controlled by config.offsetsPerDemotedVariable) for each demoted variable. When a variable is
            # demoted, it is assigned a scale given by :
            # demoted Scale = self.allScales[var] + 8 - offset
            attemptToDemote = [var for var in self.variableToBitwidthMap if (var[-3:] != "val" and var not in self.demotedVarsList)]

            # We approximately club batchSize number of codes in one generated C++ code, so that one generated code does
            # not become too large.
            # The divisor is clamped to 1: with nothing left to demote it would be 0, the quotient
            # would become inf and int(inf) raises OverflowError. Non-empty inputs are unaffected.
            batchSize = int(np.ceil(50 / max(1, np.ceil(len(attemptToDemote) / 50))))
            redBatchSize = np.max((batchSize, 16)) / config.offsetsPerDemotedVariable
            totalSize = len(attemptToDemote)
            numBatches = int(np.ceil(totalSize / redBatchSize))

            self.varDemoteDetails = []
            for batchIdx in tqdm(range(numBatches)):
                Util.getLogger().info("=====\nBatch %i out of %d\n=====\n" %(batchIdx + 1, numBatches))

                firstVarIndex = (totalSize * batchIdx) // numBatches
                lastVarIndex = (totalSize * (batchIdx + 1)) // numBatches
                demoteBatch = attemptToDemote[firstVarIndex:lastVarIndex]
                # 9 offsets tried for X while 'config.offsetsPerDemotedVariable' tried for other variables.
                numCodes = config.offsetsPerDemotedVariable * len(demoteBatch) + ((9 - config.offsetsPerDemotedVariable) if 'X' in demoteBatch else 0)

                self.partialCompile(
                    config.Encoding.fixed, config.Target.x86, self.sf, True, None,
                    -1 if len(demoteBatch) > 0 else 0,
                    dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
                codeId = 0
                contentToCodeIdMap = {}

                for demoteVar in demoteBatch:
                    # For each variable being demoted, we populate some variables containing information regarding demoted variable.
                    newbitwidths = dict(self.variableToBitwidthMap)
                    newbitwidths[demoteVar] = config.wordLength // 2
                    if demoteVar + "val" in newbitwidths:
                        newbitwidths[demoteVar + "val"] = config.wordLength // 2
                    # In subsequent iterations during fixed point compilation, this variable will have the variables demoted during the previous runs.
                    for alreadyDemotedVars in self.demotedVarsList:
                        newbitwidths[alreadyDemotedVars] = config.wordLength // 2
                    demotedVarsList = [var for var, width in newbitwidths.items() if width != config.wordLength]
                    # Start from the offsets already fixed for previously demoted variables.
                    demotedVarsOffsets = {key: self.demotedVarsOffsets[key] for key in self.demotedVarsList}

                    contentToCodeIdMap[tuple(demotedVarsList)] = {}
                    # We try out multiple offsets for each variable to find best scale assignment for each variable.
                    for demOffset in (range(0, -config.offsetsPerDemotedVariable, -1) if demoteVar != 'X' else range(0, -9, -1)):
                        codeId += 1
                        for k in demotedVarsList:
                            if k not in self.demotedVarsList:
                                demotedVarsOffsets[k] = demOffset
                        contentToCodeIdMap[tuple(demotedVarsList)][demOffset] = codeId
                        compiled = self.partialCompile(
                            config.Encoding.fixed, config.Target.x86, self.sf, False, codeId,
                            -1 if codeId != numCodes else codeId,
                            dict(newbitwidths), list(demotedVarsList), dict(demotedVarsOffsets))
                        if compiled == False:
                            Util.getLogger().error("Variable bitwidth exploration resulted in a compilation error\n")
                            return False

                res, shouldExit = self.runAll(config.Encoding.fixed, config.DatasetType.training, None, contentToCodeIdMap)
                # A failed Stage III run has never aborted the search; that is kept, but the
                # failure is now reported instead of being dropped silently.
                if shouldExit == True or res == False:
                    Util.getLogger().warning("Stage III code run reported a failure for batch %d out of %d; continuing with the remaining batches.\n" % (batchIdx + 1, numBatches))

            print("Stage IV Exploration: Cumulatively demoting variables...")
            # Stage IV exploration.
            # Again, we compute only a limited number of inference codes per generated C++ so as to not bloat up the memory usage of the compiler.
            redBatchSize *= config.offsetsPerDemotedVariable
            totalSize = len(self.varDemoteDetails)
            numBatches = int(np.ceil(totalSize / redBatchSize))

            # Candidates whose first variable found in self.varSizes is at least
            # Util.Config.largeVariableLimit are tried first; all others follow.
            sortedVars1 = []
            sortedVars2 = []
            for ((demoteVars, offset), _) in self.varDemoteDetails:
                variableInMap = False
                for demoteVar in demoteVars:
                    if demoteVar in self.varSizes:
                        variableInMap = True
                        if self.varSizes[demoteVar] >= Util.Config.largeVariableLimit:
                            sortedVars1.append((demoteVars, offset))
                            break
                        else:
                            sortedVars2.append((demoteVars, offset))
                            break
                if not variableInMap:
                    sortedVars2.append((demoteVars, offset))

            sortedVars = sortedVars1 + sortedVars2

            self.varDemoteDetails = []
            demotedVarsOffsets = dict(self.demotedVarsOffsets)
            demotedVarsList = list(self.demotedVarsList)
            demotedVarsListToOffsets = {}

            # Knowing the accuracy when each single variable is demoted to 8-bits one at a time, we proceed to cumulatively
            # demoting all of them one after the other ensuring accuracy of target code does not fall below a threshold. The
            # following for loop controls generation of inference codes.
            for batchIdx in tqdm(range(numBatches)):
                Util.getLogger().info("=====\nBatch %i out of %d\n=====\n" %(batchIdx + 1, numBatches))

                firstVarIndex = (totalSize * batchIdx) // numBatches
                lastVarIndex = (totalSize * (batchIdx + 1)) // numBatches
                demoteBatch = sortedVars[firstVarIndex:lastVarIndex]

                # NOTE(review): this tests len(attemptToDemote) whereas the matching Stage III call
                # tests len(demoteBatch) - confirm which one is intended.
                self.partialCompile(
                    config.Encoding.fixed, config.Target.x86, self.sf, True, None,
                    -1 if len(attemptToDemote) > 0 else 0,
                    dict(self.variableToBitwidthMap), list(self.demotedVarsList), dict(self.demotedVarsOffsets))
                contentToCodeIdMap = {}
                codeId = 0
                numCodes = len(demoteBatch)
                for (demoteVars, offset) in demoteBatch:
                    newbitwidths = dict(self.variableToBitwidthMap)
                    for var in demoteVars:
                        if var not in self.demotedVarsList:
                            newbitwidths[var] = config.wordLength // 2
                            demotedVarsOffsets[var] = offset
                        # demotedVarsList keeps growing across candidates and batches (cumulative demotion).
                        if var not in demotedVarsList:
                            demotedVarsList.append(var)
                    codeId += 1
                    contentToCodeIdMap[tuple(demotedVarsList)] = {}
                    contentToCodeIdMap[tuple(demotedVarsList)][offset] = codeId
                    demotedVarsListToOffsets[tuple(demotedVarsList)] = dict(demotedVarsOffsets)
                    compiled = self.partialCompile(
                        config.Encoding.fixed, config.Target.x86, self.sf, False, codeId,
                        -1 if codeId != numCodes else codeId,
                        dict(newbitwidths), list(demotedVarsList), dict(demotedVarsOffsets))
                    if compiled == False:
                        Util.getLogger().error("Variable bitwidth exploration resulted in another compilation error\n")
                        return False

                res, shouldExit = self.runAll(config.Encoding.fixed, config.DatasetType.training, None, contentToCodeIdMap, True)
                if shouldExit == True or res == False:
                    return False

            # The following for loop controls how many variables are actually demoted in the final output code, which has
            # as many variables as possible in 8-bits, while ensuring accuracy drop compared to floating point is reasonable:
            okToDemote = ()
            acceptedAcc = lastStageAcc
            for ((demotedVars, _), metrics) in self.varDemoteDetails:
                acc = metrics[0]
                if self.problemType == config.ProblemType.classification and (self.flAccuracy - acc) > config.permittedClassificationAccuracyLoss:
                    break
                elif self.problemType == config.ProblemType.regression and acc > config.permittedRegressionNumericalLossMargin:
                    break
                else:
                    okToDemote = demotedVars
                    acceptedAcc = acc

            self.demotedVarsList = list(okToDemote) + list(self.demotedVarsList)
            self.demotedVarsOffsets.update(demotedVarsListToOffsets.get(okToDemote, {}))

            if acceptedAcc != lastStageAcc:
                lastStageAcc = acceptedAcc
            else:
                Util.getLogger().warning("No difference in iteration %d's stages 1 & 2. Stopping search."%fixedPointCounter)
                break

        # Another pass is only made when iterating VBW on top of fixed-point search.
        if not config.vbwEnabled or not config.fixedPointVbwIteration:
            break

    return True