def visitMbconv()

in tools/SeeDot/seedot/compiler/ir/irBuilder.py [0:0]


    def visitMbconv(self, node: AST.MBConv):
        """Lower an MBConv node into one fused ``MBConv`` library call.

        MBConv chains three stages. Each stage multiplies by a filter
        (F1/F2/F3), reduces the products with a tree sum and applies a batch
        normalisation (bias B*, weight W*); stages 1 and 2 additionally clamp
        with ReLU6. Stage 1 sums ``Cin`` terms, stage 2 an ``Hf x Wf`` window
        and stage 3 ``Ct`` terms (see the tree-sum depths d1/d2/d3).

        In fixed-point mode this method also derives every intermediate scale
        and converts the scale differences into the nine ``shr``/``shl``
        shift arguments used by the generated code.

        Args:
            node: The MBConv AST node. ``expr1`` is the 4-D input;
                ``exprF*``/``exprW*``/``exprB*`` are the per-stage filter,
                BN weight and BN bias operands.

        Returns:
            ``(prog, expr_out)``: the IR program evaluating all operands
            followed by the MBConv call, and the output variable.

        Raises:
            AssertionError: if DDS or VBW mode is disabled, or the target is
                neither x86 nor M3.
        """
        # Explicit raises instead of ``assert`` so the checks also run under
        # ``python -O``; AssertionError keeps the exception type unchanged.
        if not (config.ddsEnabled and config.vbwEnabled):
            raise AssertionError("MBConv is currently only supported if VBW and DDS modes are switched on")
        if not (forX86() or forM3()):
            raise AssertionError("MBConv not implemented for Arduino devices")

        # Process all inputs for MBConv.
        (prog_in_A, expr_in_A) = self.visit(node.expr1)

        (prog_in_F1, expr_in_F1) = self.visit(node.exprF1)
        (prog_in_W1, expr_in_W1) = self.visit(node.exprW1)
        (prog_in_B1, expr_in_B1) = self.visit(node.exprB1)

        (prog_in_F2, expr_in_F2) = self.visit(node.exprF2)
        (prog_in_W2, expr_in_W2) = self.visit(node.exprW2)
        (prog_in_B2, expr_in_B2) = self.visit(node.exprB2)

        (prog_in_F3, expr_in_F3) = self.visit(node.exprF3)
        (prog_in_W3, expr_in_W3) = self.visit(node.exprW3)
        (prog_in_B3, expr_in_B3) = self.visit(node.exprB3)

        # Output plus three scratch buffers: U (tree sums), X and T (stage outputs).
        [expr_treeSum, expr_out] = self.getTempVars(2)
        [expr_bufX, expr_bufT] = self.getTempVars(2)

        [N, H, W, Cin] = node.expr1.type.shape
        [_, _, _, _, Ct] = node.exprF1.type.shape
        [_, Hf, Wf, _, _] = node.exprF2.type.shape
        [_, _, _, _, Cout] = node.exprF3.type.shape

        # U is shared by all three tree sums, so it must fit the largest one.
        type_treeSum = Type.Tensor([np.max((Hf * Wf, Ct, Cin))])
        type_out = node.type
        type_bufX = Type.Tensor([Hf, W, Ct])
        type_bufT = Type.Tensor([Ct])

        # Process bit-width and scales for all inputs.
        bitwidth_in_A, scale_in_A = self.getBitwidthAndScale(expr_in_A.idf)

        bitwidth_in_F1, scale_in_F1 = self.getBitwidthAndScale(expr_in_F1.idf)
        bitwidth_in_W1, scale_in_W1 = self.getBitwidthAndScale(expr_in_W1.idf)
        bitwidth_in_B1, scale_in_B1 = self.getBitwidthAndScale(expr_in_B1.idf)

        bitwidth_in_F2, scale_in_F2 = self.getBitwidthAndScale(expr_in_F2.idf)
        bitwidth_in_W2, scale_in_W2 = self.getBitwidthAndScale(expr_in_W2.idf)
        bitwidth_in_B2, scale_in_B2 = self.getBitwidthAndScale(expr_in_B2.idf)

        bitwidth_in_F3, scale_in_F3 = self.getBitwidthAndScale(expr_in_F3.idf)
        bitwidth_in_W3, scale_in_W3 = self.getBitwidthAndScale(expr_in_W3.idf)
        bitwidth_in_B3, scale_in_B3 = self.getBitwidthAndScale(expr_in_B3.idf)

        bitwidth_out, scale_out = self.getBitwidthAndScale(expr_out.idf)

        # Tree-sum depths: ceil(log2(number of summed terms)) for each stage.
        d1 = int(np.ceil(np.log2(Cin)))
        d2 = int(np.ceil(np.log2(Hf * Wf)))
        d3 = int(np.ceil(np.log2(Ct)))

        # Compute intermediate scales and the nine raw scale shifts (positive means shift right).
        if not forFloat():
            # Stage 1 (steps 1-3): multiplication, tree sum, batch normalisation and ReLU6.
            (bitwidth_u1_code, bitwidth_add1, bitwidth_mul1_code,
             shift_sum1, shift_bias1, scale_mul1) = self._mbconvBatchNormScales(
                expr_out.idf + "t1", bitwidth_in_A, scale_in_A, bitwidth_in_F1, scale_in_F1,
                bitwidth_in_W1, scale_in_W1, scale_in_B1, d1)
            six1, bitwidth_x, scale_x, shift_relu1 = self._mbconvRelu6Scales(
                bitwidth_add1, bitwidth_in_W1, scale_mul1)

            # Stage 2 (steps 4-6): multiplication, tree sum, batch normalisation and ReLU6.
            (bitwidth_u2_code, bitwidth_add2, bitwidth_mul2_code,
             shift_sum2, shift_bias2, scale_mul2) = self._mbconvBatchNormScales(
                expr_out.idf + "t3", bitwidth_x, scale_x, bitwidth_in_F2, scale_in_F2,
                bitwidth_in_W2, scale_in_W2, scale_in_B2, d2)
            six2, bitwidth_t, scale_t, shift_relu2 = self._mbconvRelu6Scales(
                bitwidth_add2, bitwidth_in_W2, scale_mul2)

            # Stage 3 (steps 7-9): multiplication, tree sum and batch normalisation (no ReLU6).
            (bitwidth_u3_code, _, bitwidth_mul3_code,
             shift_sum3, shift_bias3, scale_mul3) = self._mbconvBatchNormScales(
                expr_out.idf + "t5", bitwidth_t, scale_t, bitwidth_in_F3, scale_in_F3,
                bitwidth_in_W3, scale_in_W3, scale_in_B3, d3)

            shifts = [
                shift_sum1, shift_bias1, shift_relu1,
                shift_sum2, shift_bias2, shift_relu2,
                shift_sum3, shift_bias3, scale_out - scale_mul3,
            ]
        else:
            # In floating-point mode, none of the following values matter. Setting them to dummy values.
            shifts = [1] * 9
            bitwidth_t = bitwidth_x = config.wordLength
            bitwidth_u1_code = bitwidth_u2_code = bitwidth_u3_code = config.wordLength
            six1 = six2 = 6.0
            scale_x = scale_t = 0

        # Split each signed shift into a (right, left) pair; the unused side is a zero shift.
        shr, shl = [], []
        for shift in shifts:
            if shift >= 0:
                shr.append(self.formatShr(shift, saturate=False))
                shl.append(self.formatShr(0))
            else:
                shr.append(self.formatShr(0))
                shl.append(self.formatShr(-shift, saturate=False))

        for expr in (expr_in_A, expr_in_F1, expr_in_W1, expr_in_B1,
                     expr_in_F2, expr_in_W2, expr_in_B2,
                     expr_in_F3, expr_in_W3, expr_in_B3,
                     expr_out, expr_treeSum, expr_bufT, expr_bufX):
            expr.inputVar = False

        bitwidth_u = np.max((bitwidth_u1_code, bitwidth_u2_code, bitwidth_u3_code))

        # Setting metadata.
        if forFixed():
            self.varsForBitwidth[expr_treeSum.idf] = bitwidth_u
            self.varsForBitwidth[expr_bufT.idf] = bitwidth_t
            self.varsForBitwidth[expr_bufX.idf] = bitwidth_x

        comment = IR.Comment('MBconv(%s)' %(expr_in_A.idf), self.counter_inst+1)
        self.allDepths[self.counter_inst+1] = self.curDepth

        # Generating the argument map which is used in the codegen (insertion order is preserved).
        argMap = {
            expr_in_A: "A",
            expr_in_F1: "F1",
            expr_in_W1: "BN1W",
            expr_in_B1: "BN1B",
            expr_in_F2: "F2",
            expr_in_W2: "BN2W",
            expr_in_B2: "BN2B",
            expr_in_F3: "F3",
            expr_in_W3: "BN3W",
            expr_in_B3: "BN3B",
            expr_out: "C",
            expr_bufX: "X",
            expr_bufT: "T",
            expr_treeSum: "U",
            IR.Int(N): "N",
            IR.Int(H): "H",
            IR.Int(W): "W",
            IR.Int(Cin): "Cin",
            IR.Int(Ct): "Ct",
            IR.Int(Hf): "HF",
            IR.Int(Wf): "WF",
            IR.Int(Cout): "Cout",
            IR.Int(type_out.shape[1]): "Hout",
            IR.Int(type_out.shape[2]): "Wout",
            IR.Int(node.padding[0]): "HPADL",
            IR.Int(node.padding[1]): "HPADR",
            IR.Int(node.padding[2]): "WPADL",
            IR.Int(node.padding[3]): "WPADR",
            IR.Int(node.stride[0]): "HSTR",
            IR.Int(node.stride[1]): "WSTR",
            IR.Int(d1): "D1",
            IR.Int(d2): "D2",
            IR.Int(d3): "D3",
            IR.Int(six1): "SIX_1",
            IR.Int(six2): "SIX_2",
        }

        for i, shiftArg in enumerate(shr, start=1):
            argMap[shiftArg] = "shr%d" % i

        for i, shiftArg in enumerate(shl, start=1):
            argMap[shiftArg] = "shl%d" % i

        # These are used to optimise the m3 codegen to club multiple scale modification operators into one for faster code.
        self.biasShifts[expr_in_B1.idf] = int(np.log2(shr[1].n)) - int(np.log2(shl[1].n))
        self.biasShifts[expr_in_B2.idf] = int(np.log2(shr[4].n)) - int(np.log2(shl[4].n))
        self.biasShifts[expr_in_B3.idf] = int(np.log2(shr[7].n)) - int(np.log2(shl[7].n))

        if forFloat():
            argMap[IR.String(expr_out)] = "name"

        if forFloat():
            funcCall = IR.FuncCall("MBConv", argMap)
        else:
            # Integer types for every operand, scratch buffer and intermediate product, in template order.
            templateBitwidths = (
                bitwidth_in_A, bitwidth_in_F1, bitwidth_in_W1, bitwidth_in_B1,
                bitwidth_in_F2, bitwidth_in_W2, bitwidth_in_B2,
                bitwidth_in_F3, bitwidth_in_W3, bitwidth_in_B3,
                bitwidth_out, bitwidth_x, bitwidth_t, bitwidth_u,
                bitwidth_mul1_code, bitwidth_mul2_code, bitwidth_mul3_code,
            )
            templateArgs = "<" + ", ".join("int%s_t" % b for b in templateBitwidths) + ">"
            funcCall = IR.FuncCall("MBConv" + templateArgs, argMap)

        self.counter_inst += 1
        self.updateLiveRange([expr_in_A, expr_in_F1, expr_in_F2, expr_in_F3, expr_in_W1, expr_in_W2, expr_in_W3, expr_in_B1, expr_in_B2, expr_in_B3, expr_out, expr_treeSum, expr_bufX, expr_bufT])

        # Profiling the output variable in floating-point mode for computing the scale of the fixed-point code.
        profile = IR.FuncCall("Profile4", {
            expr_out: "Var",
            IR.Int(N): "I",
            IR.Int(type_out.shape[1]): "J",
            IR.Int(type_out.shape[2]): "K",
            IR.Int(Cout): "L",
            IR.String(expr_out): "VarName"
        })
        if forFloat():
            self.independentVars.append(expr_out.idf)

        # NOTE(review): this reads self.ddsEnabled while the guard at the top reads
        # config.ddsEnabled; confirm the two are meant to agree.
        prog_mbconv = IR.Prog([comment, funcCall, profile] if forFloat() and self.ddsEnabled else [comment, funcCall])
        prog_out = IRUtil.concatPrograms(prog_in_A, prog_in_F1, prog_in_W1, prog_in_B1, prog_in_F2, prog_in_W2, prog_in_B2, prog_in_F3, prog_in_W3, prog_in_B3, prog_mbconv)

        # Update metadata.
        self.varDeclarations[expr_out.idf] = type_out
        self.varDeclarations[expr_treeSum.idf] = type_treeSum
        self.varDeclarations[expr_bufX.idf] = type_bufX
        self.varDeclarations[expr_bufT.idf] = type_bufT

        self.varScales[expr_out.idf] = scale_out
        self.varScales[expr_treeSum.idf] = 0 # It changes across three stages above and an exact value not required outside of this method.
        self.varScales[expr_bufX.idf] = scale_x
        self.varScales[expr_bufT.idf] = scale_t

        # Intervals not needed necessarily for the compiler to run, updating this variable for being compatible with old SeeDot (PLDI '19).
        self.varIntervals[expr_out.idf] = (0, 0)
        self.varIntervals[expr_treeSum.idf] = (0, 0)
        self.varIntervals[expr_bufX.idf] = (0, 0)
        self.varIntervals[expr_bufT.idf] = (0, 0)

        # Printing log.
        self.log.print(comment.msg)
        self.log.print("\tInput1: scale = %d, interval = [%d, %d]" % (
            (self.varScales[expr_in_A.idf],) + self.varIntervals[expr_in_A.idf]))
        self.log.print("\tOutput: scale = %d, interval = [%d, %d]" % (
            (self.varScales[expr_out.idf],) + self.varIntervals[expr_out.idf]))

        return (prog_out, expr_out)

    def _mbconvBatchNormScales(self, tempIdf, bitwidth_in, scale_in, bitwidth_F, scale_F,
                               bitwidth_W, scale_W, scale_B, depth):
        """Compute fixed-point bit-widths and shifts for one MBConv stage up to batch normalisation.

        The stage multiplies its input by filter F, tree-sums the products,
        adds the BN bias B and multiplies by the BN weight W.

        Args:
            tempIdf: Variable name whose recorded scale is used for the BN
                addition (``<out>t1``, ``<out>t3`` or ``<out>t5``).
            bitwidth_in, scale_in: Bit-width and scale of the stage input.
            bitwidth_F, scale_F: Bit-width and scale of the filter.
            bitwidth_W, scale_W: Bit-width and scale of the BN weight.
            scale_B: Scale of the BN bias.
            depth: Tree-sum depth, ceil(log2(number of summed terms)).

        Returns:
            ``(bitwidth_u_code, bitwidth_add, bitwidth_mul_code, shift_sum,
            shift_bias, scale_mul)``: ``shift_sum``/``shift_bias`` bring the
            tree sum and the bias to the BN addition scale, and ``scale_mul``
            is the scale after multiplying by W.
        """
        # Multiplication by the filter; the tree sum then adds `depth` to the scale.
        bitwidth_u_code = self.getTempBitwidth(bitwidth_in, bitwidth_F, "mul")
        scale_u = scale_in + scale_F + depth

        # The BN addition uses the scale recorded for `tempIdf`, offset by the
        # headroom of the word length over the wider operand. Unpack first:
        # adding an int directly to the returned pair only worked through
        # NumPy broadcasting.
        bitwidth_add = np.max((bitwidth_in, bitwidth_F))
        bitwidth_reduction = config.wordLength - bitwidth_add
        _, scale_temp = self.getBitwidthAndScale(tempIdf)
        scale_add = scale_temp + bitwidth_reduction

        bitwidth_mul_code = self.getTempBitwidth(bitwidth_add, bitwidth_W, "mul")
        scale_mul = scale_add + scale_W
        return (bitwidth_u_code, bitwidth_add, bitwidth_mul_code,
                scale_add - scale_u, scale_add - scale_B, scale_mul)

    @staticmethod
    def _mbconvRelu6Scales(bitwidth_add, bitwidth_W, scale_mul):
        """Compute the ReLU6 clamp constant and the requantisation into the next stage's buffer.

        Returns:
            ``(six, bitwidth_next, scale_next, shift)``: the constant 6 at
            ``scale_mul``, the bit-width and scale of the buffer feeding the
            next stage (leaving room for the integer part of 6), and the
            right-shift from ``scale_mul`` to ``scale_next``.
        """
        six = 6 * (2 ** -scale_mul)
        bitwidth_next = np.max((bitwidth_add, bitwidth_W))
        # int(floor(log2(6)) + 1) == 3 integer bits are needed to hold 6.
        scale_next = -bitwidth_next + 1 + int(np.floor(np.log2(6)) + 1)
        return six, bitwidth_next, scale_next, scale_next - scale_mul