def visitBop2()

in tools/SeeDot/seedot/compiler/ir/irBuilder.py [0:0]


    def visitBop2(self, node: AST.Bop2):
        """Lower an addition/subtraction node (SeeDotParser.ADD / SUB) to IR.

        Both operands are visited first; their programs precede any code
        generated here.

        * Int result: the result is an ``IR.IntBop`` expression and no code is
          generated. The operands' scales/intervals must not be tracked.
        * Tensor result (2D, or 4D only when VBW is enabled): emits a call to
          ``MatAdd``/``MatSub``. The function name gets a suffix as follows:
            - ``BroadCastA``/``BroadCastB`` when an operand is a scalar;
            - for 2D addition without broadcasting, one letter per operand,
              ``C`` if it is a global (model parameter) and ``N`` otherwise;
            - ``4`` for 4D.
          In fixed-point mode it may add an ``AdjustScaleShl``/``AdjustScaleShr``
          call. When generating floating-point code with DDS enabled, it adds a
          ``Profile2``/``Profile4`` call instead of the adjustment. The output
          variable's declaration, scale and interval are recorded.

        Args:
            node: the Bop2 AST node. ``node.expr1``/``node.expr2`` are the
                operands and ``node.type`` is the result type.

        Returns:
            Tuple ``(prog_out, expr_out)``: the generated program and the
            expression holding the result (an ``IR.IntBop`` or a fresh
            temporary variable).

        Raises:
            NotImplementedError: if ``node.op`` is neither ADD nor SUB.
        """
        (prog_in_A, expr_in_A) = self.visit(node.expr1)
        (prog_in_B, expr_in_B) = self.visit(node.expr2)

        type_out = node.type

        if node.op == SeeDotParser.ADD:
            (op_ir, op_fn) = (IR.Op.Op['+'], operator.add)
            funcName = "MatAdd"
        elif node.op == SeeDotParser.SUB:
            (op_ir, op_fn) = (IR.Op.Op['-'], operator.sub)
            funcName = "MatSub"
        else:
            # Without this branch op_ir/op_fn/funcName would stay unbound and the
            # method would fail later with an opaque UnboundLocalError.
            raise NotImplementedError("Unsupported operator in Bop2 node: %r" % (node.op,))

        # e : Int
        if Type.isInt(type_out):
            prog_out = IRUtil.concatPrograms(prog_in_A, prog_in_B)
            expr_out = IR.IntBop(expr_in_A, op_ir, expr_in_B)

            # Just to be safe that the scaling factor of the integer variable is never tracked.
            if isinstance(expr_in_A, IR.Var):
                assert expr_in_A.idf not in self.varScales and expr_in_A.idf not in self.varIntervals
            if isinstance(expr_in_B, IR.Var):
                assert expr_in_B.idf not in self.varScales and expr_in_B.idf not in self.varIntervals
            return (prog_out, expr_out)

        # e : Tensor(), or Tensor(..)
        assert type_out.dim == 2 or (type_out.dim == 4 and config.vbwEnabled), "Addition/subtraction of tensors is currently only supported for 2D tensors. Addition for 4D tensors is supported when VBW is enabled"

        type_A = node.expr1.type
        type_B = node.expr2.type

        assert (not type_out.dim == 4) or (type_A.dim == type_B.dim and expr_in_A.idf not in self.globalVars and expr_in_B.idf not in self.globalVars and node.op == SeeDotParser.ADD), "For 4D operation, no broadcasting supported, inputs should not be model parameters, and operation cannot be subtraction"

        # Depending on whether one of the inputs is a model parameter, change the function name so that
        # the model parameter is read differently in the arduino codegen. No difference in case of x86 code.
        c = ''
        if op_fn == operator.add:
            c = ('C' if expr_in_A.idf in self.globalVars else 'N') + \
                ('C' if expr_in_B.idf in self.globalVars else 'N')

        # If one of the inputs is a scalar, the operator will broadcast that input
        # (and the model-parameter suffix is dropped).
        if type_A.dim == 0:
            funcName += 'BroadCastA'
            c = ''
        elif type_B.dim == 0:
            funcName += 'BroadCastB'
            c = ''

        expr_out = self.getTempVar()

        # Read input scales.
        bitwidth_in_A, scale_in_A = self.getBitwidthAndScale(expr_in_A.idf)
        bitwidth_in_B, scale_in_B = self.getBitwidthAndScale(expr_in_B.idf)
        # Read the output scale. Without DDS it is left to getScaleForAddAndSub (static scaling).
        if self.ddsEnabled:
            bitwidth_out, scale_out = self.getBitwidthAndScale(expr_out.idf)
        else:
            bitwidth_out = config.wordLength // 2 if expr_out.idf in self.demotedVarsList else config.wordLength
            scale_out = None

        # Compute scaling hyperparameters given input and output scales. If static scaling of old SeeDot
        # is used, also compute the output scale and bit-width.
        (scale_out_unadjusted, intv_out, [shr_A, shr_B, shr_out]) = self.getScaleForAddAndSub(scale_in_A, scale_in_B, scale_out, op_fn)
        if scale_out is None:
            scale_out = scale_out_unadjusted

        # The output shift is capped at 8. Any excess goes into the "demote"
        # argument, which is only passed to the VBW variants below.
        demoteLog = max(shr_out - 8, 0)
        shr_out = min(shr_out, 8)
        irdemote = self.formatShr(demoteLog)

        if type_out.dim == 2:
            [I, J] = type_out.shape
            dims = [(I, "I"), (J, "J")]
        elif type_out.dim == 4:
            [N, H, W, C] = type_out.shape
            dims = [(N, "N"), (H, "H"), (W, "W"), (C, "C")]
        else:
            assert False, "Unsupported dimension for addition"

        def shapeArgs():
            # Shape arguments for an IR.FuncCall, in declaration order. A fresh
            # IR.Int node is built for every call site.
            return [(IR.Int(size), name) for (size, name) in dims]

        shr_A = self.formatShr(shr_A)
        shr_B = self.formatShr(shr_B)
        shr_out = self.formatShr(shr_out)

        expr_in_A.inputVar = False
        expr_in_B.inputVar = False
        expr_out.inputVar = False

        comment = IR.Comment(expr_in_A.idf + ' ' +
                             op_ir.name + ' ' + expr_in_B.idf, self.counter_inst+1)
        self.allDepths[self.counter_inst+1] = self.curDepth

        # Generate the output function call. The 2D and 4D variants differ only in the name suffix,
        # the output parameter name and the shape arguments. VBW adds template bitwidths and "demote".
        is2D = type_out.dim == 2
        callName = funcName + (c if is2D else "4")
        if self.vbwEnabled:
            # "add" is passed for both ADD and SUB.
            tempBitwidth = self.getTempBitwidth(bitwidth_in_A, bitwidth_in_B, "add", bitwidth_out)
            callName += "<int%d_t, int%d_t, int%d_t, int%d_t>" % (bitwidth_in_A, bitwidth_in_B, tempBitwidth, bitwidth_out)
        callArgs = [(expr_in_A, "A"), (expr_in_B, "B"), (expr_out, "C" if is2D else "X")]
        callArgs += shapeArgs()
        callArgs += [(shr_A, "shrA"), (shr_B, "shrB"), (shr_out, "shrC")]
        if self.vbwEnabled:
            callArgs.append((irdemote, "demote"))
        funcCall = IR.FuncCall(callName, dict(callArgs))

        self.counter_inst += 1
        self.updateLiveRange([expr_in_A, expr_in_B, expr_out])

        # For 4D, let the output reuse a non-global input's memory when bitwidths match.
        if type_out.dim == 4:
            if expr_in_A.idf not in self.globalVars and bitwidth_in_A == bitwidth_out:
                self.setMemorySharableVariables(expr_in_A, expr_out)
            elif expr_in_B.idf not in self.globalVars and bitwidth_in_B == bitwidth_out:
                self.setMemorySharableVariables(expr_in_B, expr_out)

        # Profile the output variable in the floating point version to obtain fixed-point scale.
        profile = IR.FuncCall("Profile2" if is2D else "Profile4",
                              dict([(expr_out, "Var")] + shapeArgs() + [(IR.String(expr_out), "VarName")]))

        if forFloat():
            self.independentVars.append(expr_out.idf)

        # The theoretical output scale in scale_out_unadjusted might be different than the profiled
        # scale scale_out. We perform a scale adjustment in this case for correctness.
        # TODO: Introduce a post-processing pass to merge consecutive scale adjustments hence generated.
        adjust = []
        if forFixed() and scale_out_unadjusted != scale_out:
            if scale_out_unadjusted > scale_out:
                adjustName = "AdjustScaleShl"
                diff_scale = 2 ** (scale_out_unadjusted - scale_out)
            else:
                adjustName = "AdjustScaleShr"
                diff_scale = 2 ** (scale_out - scale_out_unadjusted)
            if self.vbwEnabled:
                adjustName += "<int%d_t>" % bitwidth_out
            adjust = [IR.FuncCall(adjustName,
                                  dict([(expr_out, "A")] + shapeArgs() + [(IR.Int(diff_scale), "scale")]))]

        if forFloat() and self.ddsEnabled:
            prog_bop = IR.Prog([comment, funcCall, profile])
        else:
            prog_bop = IR.Prog([comment, funcCall] + adjust)

        prog_out = IRUtil.concatPrograms(prog_in_A, prog_in_B, prog_bop)

        # Updating metadata.
        self.varDeclarations[expr_out.idf] = type_out
        self.varScales[expr_out.idf] = scale_out
        self.varIntervals[expr_out.idf] = intv_out

        # Print log.
        self.log.print(comment.msg)
        self.log.print("\tInput1: scale = %d, interval = [%d, %d]" % (
            (self.varScales[expr_in_A.idf],) + self.varIntervals[expr_in_A.idf]))
        self.log.print("\tInput2: scale = %d, interval = [%d, %d]" % (
            (self.varScales[expr_in_B.idf],) + self.varIntervals[expr_in_B.idf]))
        self.log.print("\tOutput: scale = %d, interval = [%d, %d]" % (
            (self.varScales[expr_out.idf],) + self.varIntervals[expr_out.idf]))

        return (prog_out, expr_out)