tools/go-agent/instrument/plugins/rewrite/func.go (241 lines of code) (raw):

// Licensed to Apache Software Foundation (ASF) under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Apache Software Foundation (ASF) licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package rewrite import ( "fmt" "reflect" "strings" "github.com/dave/dst" "github.com/dave/dst/dstutil" "github.com/apache/skywalking-go/tools/go-agent/instrument/consts" "github.com/apache/skywalking-go/tools/go-agent/tools" ) var ( GlobalOperatorRealSetMethodName = VarPrefix + "OperatorSetOperator" GlobalOperatorRealGetMethodName = VarPrefix + "OperatorGetOperator" GlobalOperatorRealAppendTracerInitNotify = VarPrefix + "OperatorAppendInitNotify" MetricsRegisterAppender = VarPrefix + "OperatorMetricsAppender" MetricsCollectAppender = VarPrefix + "OperatorMetricsCollectAppender" GlobalOperatorTypeName = TypePrefix + "OperatorOperator" ) func (c *Context) isInitFunc(funcDecl *dst.FuncDecl) bool { return funcDecl.Name.Name == "init" && (funcDecl.Type.Params == nil || len(funcDecl.Type.Params.List) == 0) && (funcDecl.Type.Results == nil || len(funcDecl.Type.Results.List) == 0) } func (c *Context) Func(funcDecl *dst.FuncDecl, cursor *dstutil.Cursor) { // only the static method needs rewrite if funcDecl.Recv == nil { // if the method name is generated, then ignore to enhance(for adapter) if !strings.HasPrefix(funcDecl.Name.Name, GenerateMethodPrefix) && !c.isInitFunc(funcDecl) { prefix := StaticMethodPrefix if tools.ContainsDirective(funcDecl, consts.DirectivePublic) { prefix = c.titleCase.String(GenerateMethodPrefix) } funcDecl.Name = dst.NewIdent(fmt.Sprintf("%s%s%s", prefix, c.currentPackageTitle, funcDecl.Name.Name)) } } else if len(funcDecl.Recv.List) == 1 { // if contains the receiver, then enhance the receiver type field := funcDecl.Recv.List[0] var typeName string if len(field.Names) > 0 { typeName = field.Names[0].Name c.rewriteMapping.addVarMapping(typeName, typeName) } if k, v := c.enhanceTypeNameWhenRewrite(field.Type, field, -1); k != "" { c.rewriteMapping.addTypeMapping(k, v) } } c.initFunctionDetector(funcDecl) // enhance method parameter and return value c.enhanceFuncParameter(funcDecl.Type.Params) c.enhanceFuncParameter(funcDecl.Type.Results) // enhance the method body for _, stmt := range funcDecl.Body.List { c.enhanceFuncStmt(stmt) } } func (c *Context) initFunctionDetector(f *dst.FuncDecl) { initFunc := tools.FindDirective(f, consts.DirectiveInit) if initFunc != "" { if f.Recv != nil && len(f.Recv.List) > 0 { panic("init function should not have receiver") } c.appendInitFunction(f.Name.Name) } } // nolint func (c *Context) enhanceFuncStmt(stmt dst.Stmt) { // for the variables created in the sub statement, ex: if, func(), the temporary variable count should be recorded subCallTypes := []reflect.Type{ reflect.TypeOf(&dst.IfStmt{}), reflect.TypeOf(&dst.BlockStmt{}), reflect.TypeOf(&dst.TypeSwitchStmt{}), } dstutil.Apply(stmt, func(cursor *dstutil.Cursor) bool { for _, t := range subCallTypes { if reflect.TypeOf(cursor.Node()) == t { c.rewriteMapping.pushBlockStack() } } switch n := cursor.Node().(type) { case *dst.BlockStmt: for _, tmp := range n.List { c.enhanceFuncStmt(tmp) } case *dst.AssignStmt: for _, l := range n.Lhs { if k, v := c.enhanceVarNameWhenRewrite(l); k != "" { c.rewriteMapping.addVarMapping(k, v) } } for i, r := range n.Rhs { if k, v := c.enhanceTypeNameWhenRewrite(r, nil, i); k != "" { c.rewriteMapping.addTypeMapping(k, v) } } case *dst.BinaryExpr: c.rewriteVarIfExistingMapping(n.X, n) c.rewriteVarIfExistingMapping(n.Y, n) case *dst.CallExpr: c.enhanceTypeNameWhenRewrite(n.Fun, n, -1) for inx, arg := range n.Args { c.enhanceTypeNameWhenRewrite(arg, n, inx) } case *dst.ReturnStmt: for inx, arg := range n.Results { c.enhanceTypeNameWhenRewrite(arg, n, inx) } case *dst.FuncType: c.enhanceFuncParameter(n.Params) c.enhanceFuncParameter(n.Results) case *dst.ExprStmt: c.enhanceTypeNameWhenRewrite(n.X, n, -1) case *dst.TypeAssertExpr: c.enhanceTypeNameWhenRewrite(n.X, n, -1) c.enhanceTypeNameWhenRewrite(n.Type, n, -1) case *dst.IfStmt: c.enhanceFuncStmt(n.Init) c.enhanceTypeNameWhenRewrite(n.Cond, n, -1) if n.Body != nil { for _, stmt := range n.Body.List { c.enhanceFuncStmt(stmt) } } if n.Else != nil { c.enhanceFuncStmt(n.Else) } case *dst.RangeStmt: c.enhanceTypeNameWhenRewrite(n.X, n, -1) if k, v := c.enhanceVarNameWhenRewrite(n.Key); k != "" { c.rewriteMapping.addVarMapping(k, v) } if k, v := c.enhanceVarNameWhenRewrite(n.Value); k != "" { c.rewriteMapping.addVarMapping(k, v) } if n.Body != nil { for _, stmt := range n.Body.List { c.enhanceFuncStmt(stmt) } } case *dst.ValueSpec: for _, n := range n.Names { if k, v := c.enhanceVarNameWhenRewrite(n); k != "" { c.rewriteMapping.addVarMapping(k, v) } } c.enhanceTypeNameWhenRewrite(n.Type, n, -1) for _, subVal := range n.Values { c.enhanceTypeNameWhenRewrite(subVal, n, -1) } case *dst.TypeSwitchStmt: c.enhanceFuncStmt(n.Init) c.enhanceFuncStmt(n.Assign) if n.Body != nil { for _, stmt := range n.Body.List { c.enhanceFuncStmt(stmt) } } case *dst.SwitchStmt: c.enhanceFuncStmt(n.Init) c.enhanceTypeNameWhenRewrite(n.Tag, n, -1) if n.Body != nil { for _, stmt := range n.Body.List { c.enhanceFuncStmt(stmt) } } case *dst.CaseClause: for i, stmt := range n.List { c.enhanceTypeNameWhenRewrite(stmt, n, i) } for _, stmt := range n.Body { c.enhanceFuncStmt(stmt) } default: return true } return false }, func(cursor *dstutil.Cursor) bool { // all templates variables should be removed for _, t := range subCallTypes { if reflect.TypeOf(cursor.Node()) == t { c.rewriteMapping.popBlockStack() break } } return true }) } // nolint func (c *Context) rewriteVarIfExistingMapping(exp, parent dst.Expr) bool { switch n := exp.(type) { case *dst.Ident: if v := c.rewriteMapping.findVarMappingName(n.Name); v != "" { n.Name = v return true } case *dst.SelectorExpr: if pkg, ok := n.X.(*dst.Ident); ok { if imp := c.packageImport[pkg.Name]; imp != nil { tools.RemovePackageRef(parent, n, -1) return true } } return c.rewriteVarIfExistingMapping(n.X, n) case *dst.CompositeLit: c.enhanceTypeNameWhenRewrite(n.Type, n, -1) for _, elt := range n.Elts { // for struct data, ex: "&xxx{k: v}" if kv, ok := elt.(*dst.KeyValueExpr); ok { c.rewriteVarIfExistingMapping(kv.Value, elt) } } case *dst.UnaryExpr: c.enhanceTypeNameWhenRewrite(n.X, n, -1) case *dst.IndexExpr: c.rewriteVarIfExistingMapping(n.Index, n) c.rewriteVarIfExistingMapping(n.X, n) case *dst.CallExpr: c.enhanceTypeNameWhenRewrite(n.Fun, n, -1) for _, arg := range n.Args { c.rewriteVarIfExistingMapping(arg, n) } case *dst.StarExpr: c.enhanceTypeNameWhenRewrite(n.X, n, -1) case *dst.FuncLit: c.rewriteMapping.pushBlockStack() c.enhanceFuncParameter(n.Type.Params) c.enhanceFuncParameter(n.Type.Results) if n.Body != nil && len(n.Body.List) > 0 { for _, stmt := range n.Body.List { c.enhanceFuncStmt(stmt) } } c.rewriteMapping.popBlockStack() } return false } func (c *Context) enhanceFuncParameter(fields *dst.FieldList) { if fields == nil { return } for _, field := range fields.List { if len(field.Names) > 0 { for inx := range field.Names { name := field.Names[inx].Name // keep the var names for debugging c.rewriteMapping.addVarMapping(name, name) } } if k, v := c.enhanceTypeNameWhenRewrite(field.Type, field, -1); k != "" { c.rewriteMapping.addTypeMapping(k, v) } } }