in tools/SeeDot/seedot/compiler/ir/irBuilder.py [0:0]
def visitLet(self, node: AST.Let):
    """Build the IR for a let statement binding `node.name` to `node.decl` within `node.expr`.

    The RHS (`node.decl`) is visited first. One of three cases is then handled:

    * RHS is of integer type: `node.name` is declared as an `Int` internal
      variable and an explicit assignment of the RHS expression to it is
      emitted before the IR of `node.expr`.
    * `node.leftSplice` is not None: the RHS is written into a splice of the
      LHS variable via `visitLeftSplice`; in floating-point mode a `Profile2`
      call is emitted as well and the spliced variable is recorded in
      `self.independentVars`.
    * Otherwise (tensor RHS): no assignment is emitted. Scale, interval,
      bit-width, substitution and live-range bookkeeping is updated, and every
      occurrence of `node.name` in the IR of `node.expr` is substituted by the
      RHS expression. For mutable loop variables in fixed-point mode, an
      `AdjustScaleShl`/`AdjustScaleShr` call is appended after the RHS program
      when the current and the profiled scale differ.

    Returns:
        A tuple `(prog_out, expr_in)`: the concatenated IR program and the
        expression obtained from visiting `node.expr`.

    Raises:
        AssertionError: in floating-point mode, when the RHS expression has a
            different name than `node.name` and `node.name` already has an
            entry in `self.substitutions`.
    """
    def resolveSubstitution(name):
        # All metadata is stored at the end of the substitution chain.
        while name in self.substitutions:
            name = self.substitutions[name]
        return name

    # Visit RHS of the let statement.
    (prog_decl, expr_decl) = self.visit(node.decl)

    type_decl = node.decl.type
    idf = node.name

    # e1 : Int
    if Type.isInt(type_decl):
        # LHS is a new integer variable and needs to be assigned to the list of variables.
        self.varDeclarations[idf] = Type.Int()
        self.internalVars.append(idf)

        # Visit remainder of the program.
        (prog_in, expr_in) = self.visit(node.expr)

        cmd = IR.Assn(IR.Var(idf), expr_decl)
        prog_let = IR.Prog([cmd])

        prog_out = IRUtil.concatPrograms(prog_decl, prog_let, prog_in)

        return (prog_out, expr_in)

    # Left Splice case.
    if node.leftSplice is not None:
        # We have to assign the value of decl (RHS) into a splice of the LHS variable.
        parentVar = resolveSubstitution(node.name)
        # Assign the RHS to a splice of LHS.
        (prog_splice, expr_splice) = self.visitLeftSplice(node.leftSplice, expr_decl, self.varDeclarations[parentVar])
        (prog_in, expr_in) = self.visit(node.expr)

        # Profile the LHS as the value would have been updated, hence the scale required for LHS in the floating-point code may be different.
        profile = IR.Prog([])
        if forFloat():
            profile = IR.Prog([IR.FuncCall("Profile2", {
                expr_decl: "Var",
                IR.Int(node.decl.type.shape[0]): "I",
                IR.Int(node.decl.type.shape[1]): "J",
                IR.String(expr_splice): "VarName"
            })])
            self.independentVars.append(expr_splice.idf)

        prog_out = IRUtil.concatPrograms(prog_decl, prog_splice, profile, prog_in)

        return (prog_out, expr_in)

    # e1 : Tensor{(),(..)}
    # Compute the scale of the LHS variable. RHS/decl may have a different bit-width, hence the scale of LHS has to be adjusted accordingly.
    self.varScales[idf] = self.varScales[expr_decl.idf] + (config.wordLength // 2 + self.demotedVarsOffsets.get(idf, 0) if idf in self.demotedVarsList else 0)
    self.varIntervals[idf] = self.varIntervals[expr_decl.idf]

    # If LHS is demoted to lower bit-width, the RHS should also be in a lower bit-width, so scale of RHS is also adjusted.
    if idf in self.demotedVarsList:
        self.varScales[expr_decl.idf] += self.demotedVarsOffsets[idf] + config.wordLength // 2
        self.demotedVarsList.append(expr_decl.idf)
        self.demotedVarsOffsets[expr_decl.idf] = self.demotedVarsOffsets[idf]
        self.varsForBitwidth[expr_decl.idf] = config.wordLength // 2
    elif expr_decl.idf not in self.varsForBitwidth:
        self.varsForBitwidth[expr_decl.idf] = config.wordLength

    # For input X, scale is computed as follows.
    if idf == "X" and self.scaleForX is not None:
        self.varScales[idf] = self.scaleForX + (config.wordLength // 2 + self.demotedVarsOffsets.get("X", 0) if 'X' in self.demotedVarsList else 0)

    # If the let statement is a model parameter declaration, then the following is invoked.
    if isinstance(node.decl, AST.Decl):
        self.globalVars.append(idf)
        # TODO: Do I need to update varDeclarations or is it handled already?
        self.varDeclarations[idf] = node.decl.type
        expr_decl.idf = idf
        expr_decl.inputVar = True

    # For mutable variables of a loop, such variables are substituted later and the details are captured here.
    if idf in self.mutableVars:
        if forFloat() and expr_decl.idf != idf:
            self.substitutions[expr_decl.idf] = idf
        # NOTE(review): the rename is applied in both floating-point and fixed-point mode,
        # since the scale adjustment below refers to the variable as IR.Var(idf) - confirm.
        expr_decl.idf = idf

    # In fixed-point mode, for mutable variables the scales need to be adjusted which is done here.
    if forFixed() and idf in self.mutableVars:
        # Add a loop to adjust the scale back to the original one.
        curr_scale = self.varScales[idf]
        idfs = resolveSubstitution(idf)
        # Read profiled scale of the LHS (profile assumes 16-bit variables) and compute final scale depending on actual bitwidth of LHS.
        if self.ddsEnabled:
            _, raw_new_scale = self.getBitwidthAndScale(idfs)
            new_scale = raw_new_scale + (config.wordLength // 2 + self.demotedVarsOffsets[idfs] if idfs in self.demotedVarsList else 0)
            new_intv = (0, 0)
        else:
            [minVal, maxVal] = self.mutableVarsProfile[0] # TODO: This function may not work for multiple loops in a code.
            new_scale = self.getScale(max(abs(minVal), abs(maxVal))) + (config.wordLength // 2 + self.demotedVarsOffsets[idfs] if idfs in self.demotedVarsList else 0)
            new_intv = self.getInterval(new_scale, minVal, maxVal)

        # Factor by which the values have to be scaled, irrespective of the direction.
        diff_scale = 2 ** abs(curr_scale - new_scale)

        [I, J] = type_decl.shape
        bitwidth_decl, _ = self.getBitwidthAndScale(expr_decl.idf)

        # The mutable loop variable needs to have its scale adjusted so that it remains the same across iterations for correctness.
        adjust = []
        if curr_scale != new_scale:
            # Shl is used when the current scale exceeds the new one, Shr otherwise.
            adjustFunc = "AdjustScaleShl" if curr_scale > new_scale else "AdjustScaleShr"
            if self.vbwEnabled:
                # With variable bit-width enabled, the call is templated on the bit-width of the RHS.
                adjustFunc += "<int%d_t>" % (bitwidth_decl)
            adjust = [IR.FuncCall(adjustFunc, {
                IR.Var(idf): "A",
                IR.Int(I): "I",
                IR.Int(J): "J",
                IR.Int(diff_scale): "scale"
            })]

        prog_for_mutable = IR.Prog(adjust)

        # Reset the self.scale value to the profile generated one.
        self.varScales[idf] = new_scale
        self.varIntervals[idf] = new_intv
    else:
        prog_for_mutable = IR.Prog([])

    # In floating point mode, the details of substitutions are stored for use in the fixed-point version.
    # Independent variables (which are profiled) are also stored in a way to avoid duplication (profiling two names of the same entity).
    if forFloat():
        if expr_decl.idf != idf:
            if idf in self.substitutions:
                # Raised explicitly (instead of `assert False`) so that the check is not stripped under `python -O`.
                raise AssertionError("What kind of subtitutions are going on?")
            self.substitutions[idf] = expr_decl.idf

        # To ensure loop variable is correctly fed for data driven scaling.
        # The list is updated in place, as other references to it may exist.
        for i, independentVar in enumerate(self.independentVars):
            self.independentVars[i] = resolveSubstitution(independentVar)

    (prog_in, expr_in) = self.visit(node.expr)

    # TODO: When is this triggered and why is this required?
    if forFixed() and idf in self.mutableVars:
        getLogger().warning("TODO: Fix this if condition")
        # Recomputed after visiting node.expr rather than reusing the values from above.
        idfs = resolveSubstitution(idf)
        if self.ddsEnabled:
            _, raw_new_scale = self.getBitwidthAndScale(idfs)
            new_scale = raw_new_scale + (config.wordLength // 2 + self.demotedVarsOffsets[idfs] if idfs in self.demotedVarsList else 0)
            new_intv = (0, 0)
        else:
            # NOTE(review): unlike the computation before visiting node.expr, the demotion
            # offset is not added to the scale here - confirm whether this is intended.
            [minVal, maxVal] = self.mutableVarsProfile[0]
            new_scale = self.getScale(max(abs(minVal), abs(maxVal)))
            new_intv = self.getInterval(new_scale, minVal, maxVal)
        self.varScales[expr_decl.idf] = new_scale
        self.varIntervals[expr_decl.idf] = new_intv

    prog_decl = IRUtil.concatPrograms(
        prog_decl, IR.Prog([prog_for_mutable]))

    # Perform substitutions to consolidate generated names and user-provided names.
    prog_in = prog_in.subst(idf, expr_decl)
    expr_in = expr_in.subst(idf, expr_decl)

    # Consolidate the information about live ranges for lhs and rhs, given the substitutions performed above.
    if idf != expr_decl.idf and idf in self.varLiveIntervals and expr_decl.idf in self.varLiveIntervals:
        lhsInterval = self.varLiveIntervals[idf]
        rhsInterval = self.varLiveIntervals[expr_decl.idf]
        # Both names get (separate list objects holding) the union of the two live ranges.
        self.varLiveIntervals[idf] = [min(lhsInterval[0], rhsInterval[0]), max(lhsInterval[1], rhsInterval[1])]
        self.varLiveIntervals[expr_decl.idf] = list(self.varLiveIntervals[idf])

    prog_out = IRUtil.concatPrograms(prog_decl, prog_in)

    return (prog_out, expr_in)