def translateToC()

in tools/SeeDot/seedot/compiler/codegen/m3.py [0:0]


    def translateToC(self, varName, argList):
        """Map a templated SeeDot library call onto an M3 library function.

        ``varName`` is the templated call name, presumably something like
        ``MatAddNN<int16_t, int16_t, int32_t, int16_t>``. The text before ``<``
        selects the operator. Every template argument contributes one entry to
        ``bitwidths``, obtained by dropping its first three and last two
        characters (so a token shaped like ``int16_t`` yields 16). Template
        arguments are not parsed for Sigmoid/TanH; those branches read the
        bitwidth from ``self.varsForBitwidth`` instead.

        ``argList`` maps each IR argument expression to the name of the
        library parameter it binds to. It is inverted below so that each
        branch can look arguments up by parameter name.

        Returns a ``(funcName, args)`` pair. ``funcName`` is the M3 function
        to call. ``args`` maps IR expressions (either existing arguments or
        freshly built ``IR.Int``/``IR.Bool``/``IR.Var`` nodes) to the M3
        parameter name each one binds to. The insertion order of ``args``
        follows the M3 parameter list as written here.

        Raises AssertionError (via ``assert``) for operators, shapes or
        bitwidth combinations that have no M3 implementation. Because these
        are ``assert`` statements, the checks disappear under ``python -O``.
        """
        # "Op<t1, t2>" -> ["Op", "t1", "t2"]
        tokens = varName.replace('<', ' ').replace('>', '').replace(',', '').split(' ')
        name = tokens[0]
        if name.startswith("Sigmoid") or name.startswith("TanH"):
            bitwidths = []
        else:
            # Strip the 3-character prefix and 2-character suffix of each type
            # token (e.g. "int16_t" -> "16").
            bitwidths = [int(bws[3:-2]) for bws in tokens[1:]]

        # argList maps IR expression -> parameter name; invert it so the
        # branches below can fetch arguments by parameter name.
        revArgList = {value: key for key, value in argList.items()}

        assert forFixed(), "Only fixed point code for M3 supported"
        assert config.vbwEnabled, "Function calls for VBW mode only supported on M3"

        # Type checking has already been done so no exhaustive checks here
        if name[:-2] == "MatAdd" or name == "MatSub":   # MatAddNC MatAddCN MatAddCC MatAddNN MatSub
            shapeB = self.decls[revArgList["B"].idf].shape
            if shapeB[1] == 1:
                op = "add" if name[3:6] == "Add" else "sub"
                assert bitwidths[0] == bitwidths[1] == bitwidths[3]
                if op == "add":
                    assert bitwidths[0] == 16, "Not Implemented for M3"
                funcName = "q%d_v_%s" % (bitwidths[0] - 1, op)
                scret = revArgList["shrC"].n * revArgList["demote"].n
                if op == "add":
                    args = {
                        revArgList["A"] : "vec1",
                        revArgList["B"] : "vec2",
                        revArgList["I"] : "len",
                        revArgList["C"] : "ret",
                        revArgList["shrA"]: "scvec1",
                        revArgList["shrB"]: "scvec2",
                        revArgList["shrC"]: "scret",
                        revArgList["demote"]: "demote"
                    }
                else:
                    # NOTE(review): the add path passes I as "len" but this
                    # path passes J; with shapeB[1] == 1 that looks suspicious
                    # unless MatSub names its dimensions differently. Confirm
                    # against MatSub's argument list.
                    args = {
                        revArgList["A"] : "vec1",
                        revArgList["B"] : "vec2",
                        revArgList["J"] : "len",
                        revArgList["C"] : "ret",
                        revArgList["shrA"]: "scvec1",
                        revArgList["shrB"]: "scvec2",
                        IR.Int(scret) : "scret"
                    }
                return funcName, args
            else:
                assert False, "Not Implemented for M3"
        elif name[:-1] == "MatAdd": # MatAdd4
            assert bitwidths[0] == bitwidths[1] == bitwidths[3]
            funcName = "q%d_t_add" % (bitwidths[0] - 1)
            scret = revArgList["shrC"].n * revArgList["demote"].n
            args = {
                revArgList["A"] : "ten1",
                revArgList["B"] : "ten2",
                revArgList["N"] : "nbatches",
                revArgList["H"] : "nrows",
                revArgList["W"] : "ncols",
                revArgList["C"] : "nchannels",
                revArgList["X"] : "ret",
                revArgList["shrA"] : "scten1",
                revArgList["shrB"] : "scten2",
                IR.Int(scret) : "scret"
            }
            return funcName, args
        elif name[:6] == "MatAdd" and name[6:-1] == "BroadCast": # MatAddBroadCastA MatAddBroadCastB
            # The trailing letter names the scalar operand; the other one is the vector.
            if name[-1] == "A":
                shapeVec = self.decls[revArgList["B"].idf].shape
                vec = revArgList["B"]
                scalar = revArgList["A"]
                scvec = revArgList["shrB"]
                scscalar = revArgList["shrA"]
            elif name[-1] == "B":
                shapeVec = self.decls[revArgList["A"].idf].shape
                vec = revArgList["A"]
                scalar = revArgList["B"]
                scvec = revArgList["shrA"]
                scscalar = revArgList["shrB"]
            else:
                assert False, "Illegal State"
            if shapeVec[1] == 1:
                assert bitwidths[0] == bitwidths[1] == bitwidths[3] == 16
                funcName = "q15_v_scalar_add"
                scret = revArgList["shrC"].n * revArgList["demote"].n
                args = {
                    # Pass the scalar's constant value (from self.cnsts), not the variable.
                    IR.Int(self.cnsts[scalar.idf]) : "scalar",
                    vec : "vec",
                    revArgList["I"] : "len",
                    revArgList["C"] : "ret",
                    scscalar : "scscalar",
                    scvec : "scvec",
                    IR.Int(scret) : "scret"
                }
                return funcName, args
            else:
                assert False, "Not implemented for M3"
        elif name[:6] == "MatSub" and name[6:-1] == "BroadCast": # MatSubBroadCastA MatSubBroadCastB
            assert bitwidths[0] == bitwidths[1] == bitwidths[3] == 16, "Not implemented on M3"
            scret = revArgList["shrC"].n * revArgList["demote"].n
            if name[-1] == "A":
                shapeB = self.decls[revArgList["B"].idf].shape
                assert shapeB[1] == 1
                funcName = "q15_v_scalar_sub"
                args = {
                    # Pass the scalar's constant value (from self.cnsts), not the variable.
                    IR.Int(self.cnsts[revArgList["A"].idf]) : "scalar",
                    revArgList["B"] : "vec",
                    revArgList["I"] : "len",
                    revArgList["C"] : "ret",
                    revArgList["shrA"] : "scscalar",
                    revArgList["shrB"] : "scvec",
                    IR.Int(scret) : "scret"
                }
            elif name[-1] == "B":
                assert False, "Not implemented on M3"
            else:
                # Previously fell through to an UnboundLocalError on funcName.
                assert False, "Illegal State"
            return funcName, args
        elif name[:-2] == "AddOrSubCir": # AddOrSubCir2D AddOrSubCir4D
            addOrSub = "add" if revArgList["add"].b else "sub"
            bwA = bitwidths[0]
            bwB = bitwidths[1]
            bwX = bitwidths[3]
            dim = 2 if name[-2:] == "2D" else 4
            scret = revArgList["shrC"].n * revArgList["demote"].n
            if dim == 2:
                assert False, "Not implemented for M3"
            else:
                args = {
                    revArgList["A"] : "mat",
                    revArgList["B"] : "vec",
                    revArgList["N"] : "nbatches",
                    revArgList["H"] : "nrows",
                    revArgList["W"] : "ncols",
                    revArgList["C"] : "nchannels",
                    revArgList["X"] : "ret",
                    revArgList["shrA"] : "scmat",
                    revArgList["shrB"] : "scvec",
                    IR.Int(scret) : "scret"
                }
                if bwA == bwB == bwX == 16:
                    bwString = "q15_t"
                elif bwA == bwX == 8 and bwB == 16 and addOrSub == "add":
                    bwString = "q7xq15_q7_t"
                else:
                    assert False, "Not implemented for M3"
            return ("%s_%s_vec" % (bwString, addOrSub)), args
        elif name == "MulCir": # MulCir
            shapeB = self.decls[revArgList["B"].idf].shape
            if shapeB[1] == 1:
                assert bitwidths[0] == bitwidths[1] == bitwidths[3]
                funcName = "q%d_v_hadamard" % (bitwidths[0] - 1)
                scvec2 = revArgList["shrB"].n * revArgList["demote"].n
                args = {
                    revArgList["A"] : "vec1",
                    revArgList["B"] : "vec2",
                    revArgList["I"] : "len",
                    revArgList["C"] : "ret",
                    revArgList["shrA"]: "scvec1",
                    IR.Int(scvec2) : "scvec2"
                }
                return funcName, args
            else:
                assert False, "Not Implemented for M3"
        elif name.startswith("Sigmoid"): # Sigmoid SigmoidNew16
            shapeA = self.decls[revArgList["A"].idf].shape
            use_tables = useNewTableExp() or useMathExp()
            if shapeA[1] == 1:
                assert self.varsForBitwidth[revArgList["A"].idf] == 16
                funcName = "q15_v_sigmoid"
                # Optional parameters default to IR.Int(0) when not supplied.
                args = {
                    revArgList["A"] : "vec",
                    revArgList["I"] : "len",
                    revArgList["B"] : "ret",
                    revArgList.get("div", IR.Int(0)) : "div",
                    revArgList.get("add", IR.Int(0)) : "add",
                    revArgList.get("sigmoid_limit", IR.Int(0)) : "sigmoid_limit",
                    revArgList.get("scale_in", IR.Int(0)) : "scale_in",
                    revArgList.get("scale_out", IR.Int(0)) : "scale_out",
                    IR.Bool(use_tables) : "use_tables"
                }
                return funcName, args
            else:
                assert False, "Not Implemented for M3"
        elif name.startswith("TanH"): # TanH TanHNew16
            shapeA = self.decls[revArgList["A"].idf].shape
            use_tables = useNewTableExp() or useMathExp()
            if shapeA[1] == 1:
                assert self.varsForBitwidth[revArgList["A"].idf] == 16
                funcName = "q15_v_tanh"
                # Optional parameters default to IR.Int(0) when not supplied.
                args = {
                    revArgList["A"] : "vec",
                    revArgList["I"] : "len",
                    revArgList["B"] : "ret",
                    revArgList.get("scale_in", IR.Int(0)) : "scale_in",
                    revArgList.get("scale_out", IR.Int(0)) : "scale_out",
                    IR.Bool(use_tables) : "use_tables"
                }
                return funcName, args
            else:
                assert False, "Not Implemented for M3"
        elif name.startswith("Exp"): # Exp ExpNew16
            assert False, "Not Implemented for M3"
        elif name.startswith("AdjustScale"): # AdjustScaleShl AdjustScaleShr AdjustScaleShlSaturate
            if name.endswith("Saturate"):
                assert False, "Not implemented for M3"
            assert bitwidths[0] == 16, "Not implemented for M3"
            shapeA = self.decls[revArgList["A"].idf].shape
            if name.endswith("Shl"):
                assert len(shapeA) == 2, "Not implemented for M3"
            if shapeA[1] == 1:
                funcName = "q15_v_scale_%s" % ("up" if name.endswith("Shl") else "down")
                # A is both input and output. Build a second IR.Var for the
                # same identifier, presumably so that "vec" and "ret" can be
                # separate keys in args (assumes IR.Var hashes by identity;
                # TODO confirm).
                ret = IR.Var(revArgList["A"].idf)
                ret.inputVar = revArgList["A"].inputVar
                ret.internalVar = revArgList["A"].internalVar
                args = {
                    revArgList["A"] : "vec",
                    revArgList["I"] : "len",
                    ret : "ret",
                    revArgList["scale"] : "scvec"
                }
                return funcName, args
            else:
                # Previously fell off the end and returned None, which broke the
                # caller's (funcName, args) unpacking with an unrelated error.
                assert False, "Not implemented for M3"
        elif name == "Transpose": # Transpose
            assert False, "Not implemented for M3"
        elif name == "Reverse2": # Reverse
            assert bitwidths[0] == 16, "Not implemented for M3"
            funcName = "q15_m_reverse"
            args = {
                revArgList["A"] : "mat",
                revArgList["I"] : "nrows",
                revArgList["J"] : "ncols",
                revArgList["axis"] : "axis",
                revArgList["B"] : "ret"
            }
            return funcName, args
        elif name == "ScalarMul": # ScalarMul
            shapeB = self.decls[revArgList["B"].idf].shape
            if shapeB[1] == 1:
                assert bitwidths[0] == bitwidths[1] == bitwidths[3] == 16
                funcName = "q15_v_scalar_mul"
                scvec = revArgList["shrB"].n * revArgList["demote"].n
                args = {
                    # Pass the scalar's constant value (from self.cnsts), not the variable.
                    IR.Int(self.cnsts[revArgList["A"].idf]) : "scalar",
                    revArgList["B"] : "vec",
                    revArgList["I"] : "len",
                    revArgList["C"] : "ret",
                    revArgList["shrA"]: "scscalar",
                    IR.Int(scvec) : "scvec"
                }
                return funcName, args
            else:
                assert False, "Not Implemented for M3"
        elif name.startswith("MatMul"): # MatMulNN MatMulNC MatMulCC MatMulCN
            shapeB = self.decls[revArgList["B"].idf].shape
            if shapeB[1] == 1:
                bwA = bitwidths[0]
                bwB = bitwidths[1]
                bwC = bitwidths[3]
                scvec = IR.Int(revArgList["shrA"].n * revArgList["demote"].n)
                # NOTE(review): H1 is folded in here as a plain factor, whereas
                # the Convolution branch uses 2 ** H1. Confirm which scaling the
                # M3 mulvec kernels expect.
                shrB = IR.Int(revArgList["shrB"].n * revArgList["H1"].n)
                args = {
                    revArgList["A"] : "mat",
                    revArgList["B"] : "vec",
                    revArgList["I"] : "nrows",
                    revArgList["J"] : "ncols",
                    revArgList["C"] : "ret",
                    shrB : "scmat",
                    scvec : "scvec",
                    revArgList["H1"] : "scret",
                }
                if bwA == bwB == bwC == 16: # Note the order of inputs is reversed.
                    bwString = "q15"
                elif bwA == bwC == 16 and bwB == 8:
                    bwString = "q15xq7_q15"
                else:
                    assert False, "Not implemented for M3"
                funcName = "%s_m_mulvec" % bwString
                return funcName, args
            else:
                assert False, "Not Implemented for M3"
        elif name.startswith("SparseMatMul"): # SparseMatMul
            shapeB = self.decls[revArgList["B"].idf].shape
            assert revArgList["B"].idf != "X", "Sparse MatMul for X not supported on M3"
            if shapeB[1] == 1:
                bwA = bitwidths[0]
                bwB = bitwidths[2]
                bwC = bitwidths[4]
                shrC = IR.Int(revArgList["shrC"].n * revArgList["demote"].n)
                args = {
                    revArgList["Aidx"] : "row_indices",
                    revArgList["Aval"] : "mat_values",
                    revArgList["B"] : "vec",
                    revArgList["K"] : "nelem",
                    revArgList["C"] : "ret",
                    revArgList["shrA"] : "scmat",
                    revArgList["shrB"] : "scvec",
                    shrC : "scret",
                }
                if bwA == bwB == bwC == 16: # Note the order of inputs is reversed.
                    bwString = "q15"
                elif bwA == bwC == 16 and bwB == 8:
                    bwString = "q15xq7_q15"
                else:
                    assert False, "Not implemented for M3"
                funcName = "%s_m_sparse_mulvec" % bwString
                return funcName, args
            else:
                assert False, "Not Implemented for M3"
        elif name == "ArgMax": # ArgMax
            shapeA = self.decls[revArgList["A"].idf].shape
            if shapeA[1] == 1:
                assert bitwidths[0] == 16, "Not implemented for M3"
                funcName = "q15_v_argmax"
                args = {
                    revArgList["A"] : "vec",
                    revArgList["I"] : "len",
                    revArgList["index"] : "ret"
                }
                return funcName, args
            else:
                assert False, "Not implemented for M3"
        elif name.startswith("Relu"): # Relu2D Relu4D Relu6
            if name.endswith("4D") or name.endswith("2D"):
                assert False, "Not implemented for M3"
            elif name.endswith("6"):
                assert bitwidths[0] == 8, "Not implemented for M3"
                funcName = "q7_t_relu"
                args = {
                    revArgList["A"] : "vec",
                    revArgList["N"] : "nbatches",
                    revArgList["H"] : "nrows",
                    revArgList["W"] : "ncols",
                    revArgList["C"] : "nchannels",
                    revArgList["B"] : "ret",
                    revArgList["six"] : "limit",
                    revArgList["div"] : "div"
                }
                return funcName, args
            else:
                assert False, "Not implemented for M3"
        elif name == "NormaliseL2": # NormaliseL2
            assert bitwidths[0] == 16, "Not implemented for M3"
            funcName = "q15_t_l2_norm"
            args = {
                revArgList["A"] : "ten",
                revArgList["N"] : "nbatches",
                revArgList["H"] : "nrows",
                revArgList["W"] : "ncols",
                revArgList["C"] : "nchannels",
                revArgList["B"] : "ret",
                revArgList["scaleA"] : "scale_in",
                revArgList["shrA"] : "scale_out"
            }
            return funcName, args
        elif name == "Maxpool": # Maxpool
            assert False, "Not implemented for M3"
        elif name == "Convolution": # Convolution
            bwA = bitwidths[0]
            bwB = bitwidths[1]
            bwC = bitwidths[3]
            # H1 is not passed separately; 2 ** H1 is folded into scinput. The
            # tree-sum buffer and H2 arguments are not forwarded to M3 at all.
            shrA = IR.Int(revArgList["shrA"].n * 2 ** revArgList["H1"].n)
            args = {
                revArgList["A"] : "input",
                revArgList["B"] : "filter",
                revArgList["C"] : "output",
                revArgList["N"] : "N",
                revArgList["H"] : "H",
                revArgList["W"] : "W",
                revArgList["CIN"] : "CIn",
                revArgList["HF"] : "HF",
                revArgList["WF"] : "WF",
                revArgList["CINF"] : "CF",
                revArgList["COUTF"] : "COut",
                revArgList["HOUT"] : "HOut",
                revArgList["WOUT"] : "WOut",
                revArgList["G"] : "G",
                revArgList["HPADL"] : "HPadU",
                revArgList["HPADR"] : "HPadD",
                revArgList["WPADL"] : "WPadL",
                revArgList["WPADR"] : "WPadR",
                revArgList["HSTR"] : "HStride",
                revArgList["WSTR"] : "WStride",
                revArgList["HDL"] : "HDilation",
                revArgList["WDL"] : "WDilation",
                shrA : "scinput",
                revArgList["shrB"] : "scoutput",
                revArgList["demote"] : "demote"
            }
            if bwA == bwB == bwC == 16:
                bwString = "q15"
            elif bwA == 8 and bwB == 16:
                bwString = "q7xq15_q%d" % (bwC - 1)
            else:
                assert False, "Not implemented for M3"
            return "%s_convolution" % bwString, args
        elif name == "MBConv":
            # TODO: Remove the TreeSum buffer variable from the list of variables to be allocated from the scratch space, because the TreeSum variable is not used in the M3 codegen.
            bwA = bitwidths[0]
            bwF1 = bitwidths[1]
            bwW1 = bitwidths[2]
            bwB1 = bitwidths[3]
            bwF2 = bitwidths[4]
            bwW2 = bitwidths[5]
            bwB2 = bitwidths[6]
            bwF3 = bitwidths[7]
            bwW3 = bitwidths[8]
            bwB3 = bitwidths[9]
            # All filter / batch-norm weight and bias tensors must share one bitwidth.
            assert bwF1 == bwW1 == bwB1 == bwF2 == bwW2 == bwB2 == bwF3 == bwW3 == bwB3, "Not implemented for M3"
            bwB = bwF1
            bwC = bitwidths[10]
            # 2 ** D<k> is folded into the corresponding shrU<k> scale.
            shr1 = IR.Int(revArgList["shr1"].n * 2 ** revArgList["D1"].n)
            shr4 = IR.Int(revArgList["shr4"].n * 2 ** revArgList["D2"].n)
            shr7 = IR.Int(revArgList["shr7"].n * 2 ** revArgList["D3"].n)
            args = {
                revArgList["A"] : "input",
                revArgList["F1"] : "filter1",
                revArgList["BN1W"] : "BN1W",
                revArgList["BN1B"] : "BN1B",
                revArgList["F2"] : "filter2",
                revArgList["BN2W"] : "BN2W",
                revArgList["BN2B"] : "BN2B",
                revArgList["F3"] : "filter3",
                revArgList["BN3W"] : "BN3W",
                revArgList["BN3B"] : "BN3B",
                revArgList["C"] : "output",
                revArgList["X"] : "convBuffer1",
                revArgList["T"] : "convBuffer2",
                revArgList["N"] : "N",
                revArgList["H"] : "H",
                revArgList["W"] : "W",
                revArgList["Cin"] : "CIn",
                revArgList["Ct"] : "CTemp",
                revArgList["HF"] : "HF",
                revArgList["WF"] : "WF",
                revArgList["Cout"] : "COut",
                revArgList["Hout"] : "HOut",
                revArgList["Wout"] : "WOut",
                revArgList["HPADL"] : "HPadU",
                revArgList["HPADR"] : "HPadD",
                revArgList["WPADL"] : "WPadL",
                revArgList["WPADR"] : "WPadR",
                revArgList["HSTR"] : "HStride",
                revArgList["WSTR"] : "WStride",
                revArgList["SIX_1"] : "limit1",
                revArgList["SIX_2"] : "limit2",
                shr1 : "shrU1",
                revArgList["shr3"] : "shrX1",
                shr4 : "shrU2",
                revArgList["shr6"] : "shrX2",
                shr7 : "shrU3",
                revArgList["shr9"] : "shrX3",
                revArgList["shl1"] : "shlU1",
                revArgList["shl3"] : "shlX1",
                revArgList["shl4"] : "shlU2",
                revArgList["shl6"] : "shlX2",
                revArgList["shl7"] : "shlU3",
                revArgList["shl9"] : "shlX3",
            }
            if bwA == bwB == bwC:
                bwString = "q%d" % (bwA - 1)
            elif bwA == 8 and bwB == bwC == 16:
                bwString = "q7xq15_q15"
            elif bwA == 16 and bwB == 8:
                bwString = "q15xq7_q%d" % (bwC - 1)
            else:
                assert False, "Not implemented for M3"
            return "%s_mbconv_block" % bwString, args
        else:
            assert False, "Not implemented for M3"