in tools/SeeDot/seedot/compiler/ir/irBuilder.py [0:0]
def visitBop2(self, node: AST.Bop2):
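"""
Handles binary addition and subtraction (AST.Bop2).
For integer operands, emits a plain IR binary operation.
For 2D tensors (and 4D tensors when VBW is enabled), emits a call to the
MatAdd/MatSub library routines together with the fixed-point scaling parameters.
"""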
(prog_in_A, expr_in_A) = self.visit(node.expr1)
(prog_in_B, expr_in_B) = self.visit(node.expr2)
type_out = node.type
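# Select the IR operator, the Python operator (used for scale computation), and the library function name.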
if node.op == SeeDotParser.ADD:
(op_ir, op_fn) = (IR.Op.Op['+'], operator.add)
funcName = "MatAdd"
elif node.op == SeeDotParser.SUB:
(op_ir, op_fn) = (IR.Op.Op['-'], operator.sub)
funcName = "MatSub"
# e : Int
if Type.isInt(type_out):
prog_out = IRUtil.concatPrograms(prog_in_A, prog_in_B)
expr_out = IR.IntBop(expr_in_A, op_ir, expr_in_B)
# Sanity check: the scaling factor and interval of an integer variable must never be tracked.
if isinstance(expr_in_A, IR.Var):
assert expr_in_A.idf not in self.varScales and expr_in_A.idf not in self.varIntervals
if isinstance(expr_in_B, IR.Var):
assert expr_in_B.idf not in self.varScales and expr_in_B.idf not in self.varIntervals
# e : Tensor(), or Tensor(..)
else:
assert type_out.dim == 2 or (type_out.dim == 4 and config.vbwEnabled), "Addition/subtraction of tensors is currently supported only for 2D tensors; addition of 4D tensors is supported when VBW is enabled"
type_A = node.expr1.type
type_B = node.expr2.type
assert (not type_out.dim == 4) or (type_A.dim == type_B.dim and expr_in_A.idf not in self.globalVars and expr_in_B.idf not in self.globalVars and node.op == SeeDotParser.ADD), "For 4D operations, broadcasting is not supported, the inputs must not be model parameters, and the operation cannot be subtraction"
# If one of the inputs is a model parameter, change the function name so that the parameter is read differently in the Arduino codegen. This makes no difference for the x86 codegen.
c = ''
if op_fn == operator.add:
if expr_in_A.idf in self.globalVars:
c += 'C'
else:
c += 'N'
if expr_in_B.idf in self.globalVars:
c += 'C'
else:
c += 'N'
# If one of the inputs is a scalar, the operator will broadcast that input.
if type_A.dim == 0:
funcName += 'BroadCastA'
c = ''
elif type_B.dim == 0:
funcName += 'BroadCastB'
c = ''
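# Allocate a temporary variable to hold the output.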
expr_out = self.getTempVar()
# Read the input scales.
bitwidth_in_A, scale_in_A = self.getBitwidthAndScale(expr_in_A.idf)
bitwidth_in_B, scale_in_B = self.getBitwidthAndScale(expr_in_B.idf)
# Read output scale.
if self.ddsEnabled:
bitwidth_out, scale_out = self.getBitwidthAndScale(expr_out.idf)
else:
bitwidth_out = config.wordLength // 2 if expr_out.idf in self.demotedVarsList else config.wordLength
scale_out = None
# Compute the scaling hyperparameters from the input and output scales. If the static scaling of the old SeeDot is used, this also computes the output scale and bit-width.
(scale_out_unadjusted, intv_out, [shr_A, shr_B, shr_out]) = self.getScaleForAddAndSub(scale_in_A, scale_in_B, scale_out, op_fn)
if scale_out is None:
scale_out = scale_out_unadjusted
intv_in_A, intv_in_B = self.varIntervals[expr_in_A.idf], self.varIntervals[expr_in_B.idf]
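# The output shift is capped at 8; any remaining shift is applied through the separate demote parameter.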
demoteLog = shr_out - 8 if shr_out >= 8 else 0
shr_out = min(shr_out, 8)
irdemote = self.formatShr(demoteLog)
if type_out.dim == 2:
[I, J] = type_out.shape
elif type_out.dim == 4:
[N, H, W, C] = type_out.shape
else:
assert False, "Unsupported dimension for addition"
shr_A = self.formatShr(shr_A)
shr_B = self.formatShr(shr_B)
shr_out = self.formatShr(shr_out)
expr_in_A.inputVar = False
expr_in_B.inputVar = False
expr_out.inputVar = False
comment = IR.Comment(expr_in_A.idf + ' ' +
op_ir.name + ' ' + expr_in_B.idf, self.counter_inst+1)
self.allDepths[self.counter_inst+1] = self.curDepth
# Generate the library function call depending on the dimensionality of the input/output.
if type_out.dim == 2:
funcCall = IR.FuncCall(funcName + c, {
expr_in_A: "A",
expr_in_B: "B",
expr_out: "C",
IR.Int(I): "I",
IR.Int(J): "J",
shr_A: "shrA",
shr_B: "shrB",
shr_out: "shrC"
}) if not self.vbwEnabled else IR.FuncCall(funcName + c + ("<int%d_t, int%d_t, int%d_t, int%d_t>" % (bitwidth_in_A, bitwidth_in_B, self.getTempBitwidth(bitwidth_in_A, bitwidth_in_B, "add", bitwidth_out), bitwidth_out)), {
expr_in_A: "A",
expr_in_B: "B",
expr_out: "C",
IR.Int(I): "I",
IR.Int(J): "J",
shr_A: "shrA",
shr_B: "shrB",
shr_out: "shrC",
irdemote: "demote"
})
elif type_out.dim == 4:
funcCall = IR.FuncCall(funcName + "4", {
expr_in_A: "A",
expr_in_B: "B",
expr_out: "X",
IR.Int(N): "N",
IR.Int(H): "H",
IR.Int(W): "W",
IR.Int(C): "C",
shr_A: "shrA",
shr_B: "shrB",
shr_out: "shrC",
}) if not self.vbwEnabled else IR.FuncCall(funcName + "4" + ("<int%d_t, int%d_t, int%d_t, int%d_t>" % (bitwidth_in_A, bitwidth_in_B, self.getTempBitwidth(bitwidth_in_A, bitwidth_in_B, "add", bitwidth_out), bitwidth_out)), {
expr_in_A: "A",
expr_in_B: "B",
expr_out: "X",
IR.Int(N): "N",
IR.Int(H): "H",
IR.Int(W): "W",
IR.Int(C): "C",
shr_A: "shrA",
shr_B: "shrB",
shr_out: "shrC",
irdemote: "demote"
})
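# Update the instruction counter and the live ranges of the input and output variables.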
self.counter_inst += 1
self.updateLiveRange([expr_in_A, expr_in_B, expr_out])
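# For the 4D case, allow the output to reuse the memory of a non-parameter input with a matching bit-width.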
if type_out.dim == 4:
if expr_in_A.idf not in self.globalVars and bitwidth_in_A == bitwidth_out:
self.setMemorySharableVariables(expr_in_A, expr_out)
elif expr_in_B.idf not in self.globalVars and bitwidth_in_B == bitwidth_out:
self.setMemorySharableVariables(expr_in_B, expr_out)
# Profile the output variable in the floating-point version to obtain the fixed-point scale.
if type_out.dim == 2:
profile = IR.FuncCall("Profile2", {
expr_out: "Var",
IR.Int(I): "I",
IR.Int(J): "J",
IR.String(expr_out): "VarName"
})
elif type_out.dim == 4:
profile = IR.FuncCall("Profile4", {
expr_out: "Var",
IR.Int(N): "N",
IR.Int(H): "H",
IR.Int(W): "W",
IR.Int(C): "C",
IR.String(expr_out): "VarName"
})
else:
assert False, "Illegal number of dimensions"
if forFloat():
self.independentVars.append(expr_out.idf)
# The theoretical output scale scale_out_unadjusted might differ from the profiled scale scale_out.
# In that case, a scale adjustment is performed for correctness.
# TODO: Introduce a post-processing pass to merge consecutive scale adjustments generated this way.
if type_out.dim == 2:
adjust = []
if forFixed():
if scale_out_unadjusted != scale_out:
if scale_out_unadjusted > scale_out:
diff_scale = 2 ** (scale_out_unadjusted - scale_out)
adjust = [IR.FuncCall("AdjustScaleShl" + (("<int%d_t>"%bitwidth_out) if self.vbwEnabled else ""), {
expr_out: "A",
IR.Int(I): "I",
IR.Int(J): "J",
IR.Int(diff_scale): "scale"
})]
elif scale_out_unadjusted < scale_out:
diff_scale = 2 ** (scale_out - scale_out_unadjusted)
adjust = [IR.FuncCall("AdjustScaleShr" + (("<int%d_t>"%bitwidth_out) if self.vbwEnabled else ""), {
expr_out: "A",
IR.Int(I): "I",
IR.Int(J): "J",
IR.Int(diff_scale): "scale"
})]
elif type_out.dim == 4:
adjust = []
if forFixed():
if scale_out_unadjusted != scale_out:
if scale_out_unadjusted > scale_out:
diff_scale = 2 ** (scale_out_unadjusted - scale_out)
adjust = [IR.FuncCall("AdjustScaleShl" + (("<int%d_t>"%bitwidth_out) if self.vbwEnabled else ""), {
expr_out: "A",
IR.Int(N): "N",
IR.Int(H): "H",
IR.Int(W): "W",
IR.Int(C): "C",
IR.Int(diff_scale): "scale"
})]
elif scale_out_unadjusted < scale_out:
diff_scale = 2 ** (scale_out - scale_out_unadjusted)
adjust = [IR.FuncCall("AdjustScaleShr" + (("<int%d_t>"%bitwidth_out) if self.vbwEnabled else ""), {
expr_out: "A",
IR.Int(N): "N",
IR.Int(H): "H",
IR.Int(W): "W",
IR.Int(C): "C",
IR.Int(diff_scale): "scale"
})]
else:
assert False, "Illegal number of dimensions"
prog_bop = IR.Prog([comment, funcCall, profile] if forFloat() and self.ddsEnabled else [comment, funcCall] + adjust)
prog_out = IRUtil.concatPrograms(prog_in_A, prog_in_B, prog_bop)
# Update metadata.
self.varDeclarations[expr_out.idf] = type_out
self.varScales[expr_out.idf] = scale_out
self.varIntervals[expr_out.idf] = intv_out
# Print log.
self.log.print(comment.msg)
self.log.print("\tInput1: scale = %d, interval = [%d, %d]" % (
(self.varScales[expr_in_A.idf],) + self.varIntervals[expr_in_A.idf]))
self.log.print("\tInput2: scale = %d, interval = [%d, %d]" % (
(self.varScales[expr_in_B.idf],) + self.varIntervals[expr_in_B.idf]))
self.log.print("\tOutput: scale = %d, interval = [%d, %d]" % (
(self.varScales[expr_out.idf],) + self.varIntervals[expr_out.idf]))
return (prog_out, expr_out)