in tools/SeeDot/seedot/compiler/ir/irBuilder.py [0:0]
def visitMbconv(self, node: AST.MBConv):
    """Lower an MBConv (inverted-residual) block to a single MBConv() IR call.

    The block is three fused stages (presumably 1x1 expand conv, depthwise
    conv, 1x1 project conv — TODO confirm against the MBConv kernel), each
    followed by batch-normalisation, with ReLU6 after the first two stages.
    Computes per-stage scales/shift amounts for fixed-point mode, registers
    the output and scratch buffers in the compiler metadata tables, and
    returns (program, output expression) like every other visitor.

    Only supported when both DDS and VBW modes are enabled, and only for
    x86/M3 targets.
    """
    if not (config.ddsEnabled and config.vbwEnabled):
        assert False, "MBConv is currently only supported if VBW and DDS modes are switched on"
    assert forX86() or forM3(), "MBConv not implemented for Arduino devices"

    # Process all inputs for MBConv: activation A, then per-stage filter (F),
    # BN weight (W) and BN bias (B) for stages 1..3.
    (prog_in_A, expr_in_A) = self.visit(node.expr1)
    (prog_in_F1, expr_in_F1) = self.visit(node.exprF1)
    (prog_in_W1, expr_in_W1) = self.visit(node.exprW1)
    (prog_in_B1, expr_in_B1) = self.visit(node.exprB1)
    (prog_in_F2, expr_in_F2) = self.visit(node.exprF2)
    (prog_in_W2, expr_in_W2) = self.visit(node.exprW2)
    (prog_in_B2, expr_in_B2) = self.visit(node.exprB2)
    (prog_in_F3, expr_in_F3) = self.visit(node.exprF3)
    (prog_in_W3, expr_in_W3) = self.visit(node.exprW3)
    (prog_in_B3, expr_in_B3) = self.visit(node.exprB3)

    # Temporaries: tree-sum accumulator, the output, and two ring buffers
    # (X holds partially-computed rows, T a single channel vector).
    [expr_treeSum, expr_out] = self.getTempVars(2)
    [expr_bufX, expr_bufT] = self.getTempVars(2)

    [N, H, W, Cin] = node.expr1.type.shape
    [_, _, _, _, Ct] = node.exprF1.type.shape
    [_, Hf, Wf, _, _] = node.exprF2.type.shape
    [_, _, _, _, Cout] = node.exprF3.type.shape

    # The tree-sum buffer must be large enough for the widest reduction
    # performed by any of the three stages.
    type_treeSum = Type.Tensor([np.max((Hf * Wf, Ct, Cin))])
    type_out = node.type
    type_bufX = Type.Tensor([Hf, W, Ct])
    type_bufT = Type.Tensor([Ct])

    # Process bit-width and scales for all inputs.
    bitwidth_in_A, scale_in_A = self.getBitwidthAndScale(expr_in_A.idf)
    bitwidth_in_F1, scale_in_F1 = self.getBitwidthAndScale(expr_in_F1.idf)
    bitwidth_in_W1, scale_in_W1 = self.getBitwidthAndScale(expr_in_W1.idf)
    bitwidth_in_B1, scale_in_B1 = self.getBitwidthAndScale(expr_in_B1.idf)
    bitwidth_in_F2, scale_in_F2 = self.getBitwidthAndScale(expr_in_F2.idf)
    bitwidth_in_W2, scale_in_W2 = self.getBitwidthAndScale(expr_in_W2.idf)
    bitwidth_in_B2, scale_in_B2 = self.getBitwidthAndScale(expr_in_B2.idf)
    bitwidth_in_F3, scale_in_F3 = self.getBitwidthAndScale(expr_in_F3.idf)
    bitwidth_in_W3, scale_in_W3 = self.getBitwidthAndScale(expr_in_W3.idf)
    bitwidth_in_B3, scale_in_B3 = self.getBitwidthAndScale(expr_in_B3.idf)
    bitwidth_out, scale_out = self.getBitwidthAndScale(expr_out.idf)

    # One right-shift / left-shift pair per scale-adjustment point (9 total).
    shr = [0 for i in range(9)]
    shl = [0 for i in range(9)]

    # Compute intermediate scales and scaling factors for all operations
    # which are included in MBConv.
    if not forFloat():
        # Stage 1 Step 1: Multiplication.
        bitwidth_u1 = bitwidth_in_A + bitwidth_in_F1 - 1
        bitwidth_u1_code = self.getTempBitwidth(bitwidth_in_A, bitwidth_in_F1, "mul")
        scale_u1 = scale_in_A + scale_in_F1
        # Stage 1 Step 2: Tree Sum over the Cin reduction dimension.
        d1 = int(np.ceil(np.log2(Cin)))
        scale_u1 = scale_u1 + d1
        # Stage 1 Step 3: Batch Normalisation and ReLU6.
        bitwidth_add1 = np.max((bitwidth_in_A, bitwidth_in_F1))
        bitwidth_reduction = config.wordLength - bitwidth_add1
        # BUG FIX: unpack the (bitwidth, scale) tuple first, then adjust the
        # scale. The previous code added the int reduction directly to the
        # returned tuple, which raises TypeError.
        _, scale_add1 = self.getBitwidthAndScale(expr_out.idf + "t1")
        scale_add1 = scale_add1 + bitwidth_reduction
        shr[0] = (scale_add1 - scale_u1)
        shr[1] = (scale_add1 - scale_in_B1)
        bitwidth_mul1 = bitwidth_add1 + bitwidth_in_W1 - 1
        bitwidth_mul1_code = self.getTempBitwidth(bitwidth_add1, bitwidth_in_W1, "mul")
        scale_mul1 = scale_add1 + scale_in_W1
        # ReLU6 clamp constant expressed in the product's fixed-point scale.
        six1 = 6 * (2 ** -scale_mul1)
        bitwidth_x = np.max((bitwidth_add1, bitwidth_in_W1))
        # Scale for X chosen so values up to 6 (the ReLU6 ceiling) fit.
        scale_x = -bitwidth_x + 1 + int(np.floor(np.log2(6)) + 1)
        scale_shift = scale_x - scale_mul1
        shr[2] = scale_shift

        # Stage 2 Step 4: Multiplication.
        bitwidth_u2 = bitwidth_x + bitwidth_in_F2 - 1
        bitwidth_u2_code = self.getTempBitwidth(bitwidth_x, bitwidth_in_F2, "mul")
        scale_u2 = scale_x + scale_in_F2
        # Stage 2 Step 5: Tree Sum over the Hf*Wf filter window.
        d2 = int(np.ceil(np.log2(Hf * Wf)))
        scale_u2 = scale_u2 + d2
        # Stage 2 Step 6: Batch Normalisation and ReLU6.
        bitwidth_add2 = np.max((bitwidth_x, bitwidth_in_F2))
        bitwidth_reduction = config.wordLength - bitwidth_add2
        # BUG FIX: same tuple-unpacking fix as stage 1.
        _, scale_add2 = self.getBitwidthAndScale(expr_out.idf + "t3")
        scale_add2 = scale_add2 + bitwidth_reduction
        shr[3] = (scale_add2 - scale_u2)
        shr[4] = (scale_add2 - scale_in_B2)
        bitwidth_mul2 = bitwidth_add2 + bitwidth_in_W2 - 1
        bitwidth_mul2_code = self.getTempBitwidth(bitwidth_add2, bitwidth_in_W2, "mul")
        scale_mul2 = scale_add2 + scale_in_W2
        six2 = 6 * (2 ** -scale_mul2)
        bitwidth_t = np.max((bitwidth_add2, bitwidth_in_W2))
        scale_t = -bitwidth_t + 1 + int(np.floor(np.log2(6)) + 1)
        scale_shift = scale_t - scale_mul2
        shr[5] = scale_shift

        # Stage 3 Step 7: Multiplication.
        bitwidth_u3 = bitwidth_t + bitwidth_in_F3 - 1
        bitwidth_u3_code = self.getTempBitwidth(bitwidth_t, bitwidth_in_F3, "mul")
        scale_u3 = scale_t + scale_in_F3
        # Stage 3 Step 8: Tree Sum over the Ct reduction dimension.
        d3 = int(np.ceil(np.log2(Ct)))
        scale_u3 = scale_u3 + d3
        # Stage 3 Step 9: Batch Normalisation (no ReLU6 on the last stage).
        bitwidth_add3 = np.max((bitwidth_t, bitwidth_in_F3))
        bitwidth_reduction = config.wordLength - bitwidth_add3
        # BUG FIX: same tuple-unpacking fix as stages 1 and 2.
        _, scale_add3 = self.getBitwidthAndScale(expr_out.idf + "t5")
        scale_add3 = scale_add3 + bitwidth_reduction
        shr[6] = (scale_add3 - scale_u3)
        shr[7] = (scale_add3 - scale_in_B3)
        bitwidth_mul3 = bitwidth_add3 + bitwidth_in_W3 - 1
        bitwidth_mul3_code = self.getTempBitwidth(bitwidth_add3, bitwidth_in_W3, "mul")
        scale_mul3 = scale_add3 + scale_in_W3
        scale_reduction = scale_out - scale_mul3
        shr[8] = scale_reduction

        # A negative right-shift is realised as the equivalent left-shift.
        for i in range(9):
            shl[i] = -shr[i]
    else:
        d1 = int(np.ceil(np.log2(Cin)))
        d2 = int(np.ceil(np.log2(Hf * Wf)))
        d3 = int(np.ceil(np.log2(Ct)))
        # In floating-point mode, none of the following values matter.
        # Setting them to dummy values.
        for i in range(9):
            shr[i] = 1
            shl[i] = 1
        bitwidth_u = bitwidth_t = bitwidth_x = config.wordLength
        bitwidth_u1_code = bitwidth_u2_code = bitwidth_u3_code = config.wordLength
        six1 = six2 = 6.0
        scale_x = scale_t = 0

    # Convert each raw shift amount into the IR shift operand: exactly one of
    # (shr, shl) carries the magnitude; the other becomes a no-op shift by 0.
    for i in range(9):
        if shr[i] >= 0:
            shr[i] = self.formatShr(shr[i], saturate=False)
            shl[i] = self.formatShr(0)
        else:
            shr[i] = self.formatShr(0)
            shl[i] = self.formatShr(shl[i], saturate=False)

    expr_in_A.inputVar = False
    expr_in_F1.inputVar = False
    expr_in_W1.inputVar = False
    expr_in_B1.inputVar = False
    expr_in_F2.inputVar = False
    expr_in_W2.inputVar = False
    expr_in_B2.inputVar = False
    expr_in_F3.inputVar = False
    expr_in_W3.inputVar = False
    expr_in_B3.inputVar = False
    expr_out.inputVar = False
    expr_treeSum.inputVar = False
    expr_bufT.inputVar = False
    expr_bufX.inputVar = False

    # The shared tree-sum buffer must be wide enough for any stage.
    bitwidth_u = np.max((bitwidth_u1_code, bitwidth_u2_code, bitwidth_u3_code))

    # Setting metadata.
    if forFixed():
        self.varsForBitwidth[expr_treeSum.idf] = bitwidth_u
        self.varsForBitwidth[expr_bufT.idf] = bitwidth_t
        self.varsForBitwidth[expr_bufX.idf] = bitwidth_x

    comment = IR.Comment('MBconv(%s)' %(expr_in_A.idf), self.counter_inst+1)
    self.allDepths[self.counter_inst+1] = self.curDepth

    argMap = {
        expr_in_A: "A",
        expr_in_F1: "F1",
        expr_in_W1: "BN1W",
        expr_in_B1: "BN1B",
        expr_in_F2: "F2",
        expr_in_W2: "BN2W",
        expr_in_B2: "BN2B",
        expr_in_F3: "F3",
        expr_in_W3: "BN3W",
        expr_in_B3: "BN3B",
        expr_out: "C",
        expr_bufX: "X",
        expr_bufT: "T",
        expr_treeSum: "U",
        IR.Int(N): "N",
        IR.Int(H): "H",
        IR.Int(W): "W",
        IR.Int(Cin): "Cin",
        IR.Int(Ct): "Ct",
        IR.Int(Hf): "HF",
        IR.Int(Wf): "WF",
        IR.Int(Cout): "Cout",
        IR.Int(type_out.shape[1]): "Hout",
        IR.Int(type_out.shape[2]): "Wout",
        IR.Int(node.padding[0]): "HPADL",
        IR.Int(node.padding[1]): "HPADR",
        IR.Int(node.padding[2]): "WPADL",
        IR.Int(node.padding[3]): "WPADR",
        IR.Int(node.stride[0]): "HSTR",
        IR.Int(node.stride[1]): "WSTR",
        IR.Int(d1): "D1",
        IR.Int(d2): "D2",
        IR.Int(d3): "D3",
        IR.Int(six1): "SIX_1",
        IR.Int(six2): "SIX_2",
    }
    for i in range(9):
        argMap[shr[i]] = "shr%d" % (i+1)
    for i in range(9):
        argMap[shl[i]] = "shl%d" % (i+1)

    # These are used to optimise the m3 codegen to club multiple scale
    # modification operators into one for faster code.
    self.biasShifts[expr_in_B1.idf] = int(np.log2(shr[1].n)) - int(np.log2(shl[1].n))
    self.biasShifts[expr_in_B2.idf] = int(np.log2(shr[4].n)) - int(np.log2(shl[4].n))
    self.biasShifts[expr_in_B3.idf] = int(np.log2(shr[7].n)) - int(np.log2(shl[7].n))

    if forFloat():
        argMap[IR.String(expr_out)] = "name"

    # Generating the argument map which is used in the codegen.
    localVarMap = {expr_treeSum.idf: type_treeSum, expr_bufX.idf: type_bufX, expr_bufT.idf: type_bufT}
    if forFloat():
        funcCall = IR.FuncCall("MBConv", argMap)
    else:
        templateArgs = ("<int%s_t" + (", int%s_t" * 16) + ">") % (bitwidth_in_A, bitwidth_in_F1, bitwidth_in_W1, bitwidth_in_B1, bitwidth_in_F2, bitwidth_in_W2, bitwidth_in_B2, bitwidth_in_F3, bitwidth_in_W3, bitwidth_in_B3, bitwidth_out, bitwidth_x, bitwidth_t, bitwidth_u, bitwidth_mul1_code, bitwidth_mul2_code, bitwidth_mul3_code)
        funcCall = IR.FuncCall("MBConv" + templateArgs, argMap)

    self.counter_inst += 1
    self.updateLiveRange([expr_in_A, expr_in_F1, expr_in_F2, expr_in_F3, expr_in_W1, expr_in_W2, expr_in_W3, expr_in_B1, expr_in_B2, expr_in_B3, expr_out, expr_treeSum, expr_bufX, expr_bufT])

    # Profiling the output variable in floating-point mode for computing the
    # scale of the fixed-point code.
    profile = IR.FuncCall("Profile4", {
        expr_out: "Var",
        IR.Int(N): "I",
        IR.Int(type_out.shape[1]): "J",
        IR.Int(type_out.shape[2]): "K",
        IR.Int(Cout): "L",
        IR.String(expr_out): "VarName"
    })
    if forFloat():
        self.independentVars.append(expr_out.idf)

    prog_mbconv = IR.Prog([comment, funcCall, profile] if forFloat() and self.ddsEnabled else [comment, funcCall])

    prog_out = IRUtil.concatPrograms(prog_in_A, prog_in_F1, prog_in_W1, prog_in_B1, prog_in_F2, prog_in_W2, prog_in_B2, prog_in_F3, prog_in_W3, prog_in_B3, prog_mbconv)

    # Update metadata.
    self.varDeclarations[expr_out.idf] = type_out
    self.varDeclarations[expr_treeSum.idf] = type_treeSum
    self.varDeclarations[expr_bufX.idf] = type_bufX
    self.varDeclarations[expr_bufT.idf] = type_bufT

    self.varScales[expr_out.idf] = scale_out
    self.varScales[expr_treeSum.idf] = 0 # It changes across three stages above and an exact value not required outside of this method.
    self.varScales[expr_bufX.idf] = scale_x
    self.varScales[expr_bufT.idf] = scale_t

    # Intervals not needed necesarily for the compiler to run, updating this
    # variable for being compatible with old SeeDot (PLDI '19).
    self.varIntervals[expr_out.idf] = (0, 0)
    self.varIntervals[expr_treeSum.idf] = (0, 0)
    self.varIntervals[expr_bufX.idf] = (0, 0)
    self.varIntervals[expr_bufT.idf] = (0, 0)

    # Printing log.
    self.log.print(comment.msg)
    self.log.print("\tInput1: scale = %d, interval = [%d, %d]" % (
        (self.varScales[expr_in_A.idf],) + self.varIntervals[expr_in_A.idf]))
    self.log.print("\tOutput: scale = %d, interval = [%d, %d]" % (
        (self.varScales[expr_out.idf],) + self.varIntervals[expr_out.idf]))

    return (prog_out, expr_out)