tools/go-agent/tools/dst.go (361 lines of code) (raw):

// Licensed to Apache Software Foundation (ASF) under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Apache Software Foundation (ASF) licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package tools import ( "bytes" "fmt" "go/ast" "go/printer" "go/token" "io" "os" "path/filepath" "regexp" "strings" "github.com/dave/dst" "github.com/dave/dst/decorator" "github.com/dave/dst/dstutil" ) var packageImportExp = regexp.MustCompile(`^(\S+\s+)?"(.+)"$`) func ChangePackageImportPath(file dst.Node, pkgChanges map[string]string) { dstutil.Apply(file, func(cursor *dstutil.Cursor) bool { if n, ok := cursor.Node().(*dst.ImportSpec); ok { for originalPkg, targetPkg := range pkgChanges { sprintf := fmt.Sprintf("%q", originalPkg) if n.Path.Value == sprintf { n.Path.Value = fmt.Sprintf("%q", targetPkg) } } } return true }, func(cursor *dstutil.Cursor) bool { return true }) } func DeletePackageImports(file dst.Node, imports ...string) { containsDeletedImport := false deletedPackages := make(map[string]string) dstutil.Apply(file, func(cursor *dstutil.Cursor) bool { switch n := cursor.Node().(type) { case *dst.ImportSpec: for _, pkg := range imports { if n.Path.Value == fmt.Sprintf("%q", pkg) { containsDeletedImport = true cursor.Delete() if n.Name != nil { deletedPackages[n.Name.Name] = pkg } else { deletedPackages[filepath.Base(pkg)] = pkg } } } return false case *dst.SelectorExpr: pkgRefName, ok := n.X.(*dst.Ident) if !ok { return true } if _, ok := deletedPackages[pkgRefName.Name]; ok { RemovePackageRef(cursor.Parent(), n, -1) } case *dst.CaseClause: for i, d := range n.List { if sel, ok := d.(*dst.SelectorExpr); ok { pkgRefName, ok := sel.X.(*dst.Ident) if !ok { return true } if _, ok := deletedPackages[pkgRefName.Name]; ok { RemovePackageRef(n, sel, i) } } } } return true }, func(cursor *dstutil.Cursor) bool { return true }) if containsDeletedImport { RemoveImportDefineIfNoPackage(file) } } func RemovePackageRef(parent dst.Node, current *dst.SelectorExpr, inx int) { switch p := parent.(type) { case *dst.Field: p.Type = dst.NewIdent(current.Sel.Name) case *dst.Ellipsis: p.Elt = dst.NewIdent(current.Sel.Name) case *dst.StarExpr: p.X = dst.NewIdent(current.Sel.Name) case *dst.TypeAssertExpr: p.Type = dst.NewIdent(current.Sel.Name) case *dst.CompositeLit: p.Type = dst.NewIdent(current.Sel.Name) case *dst.ArrayType: p.Elt = dst.NewIdent(current.Sel.Name) case *dst.CallExpr: p.Fun = dst.NewIdent(current.Sel.Name) case *dst.KeyValueExpr: p.Value = dst.NewIdent(current.Sel.Name) case *dst.AssignStmt: p.Rhs = []dst.Expr{dst.NewIdent(current.Sel.Name)} case *dst.CaseClause: p.List[inx] = dst.NewIdent(current.Sel.Name) } } func RemoveImportDefineIfNoPackage(file dst.Node) { // remove the import decl if empty dstutil.Apply(file, func(cursor *dstutil.Cursor) bool { if decl, ok := cursor.Node().(*dst.GenDecl); ok && decl.Tok == token.IMPORT && len(decl.Specs) == 0 { cursor.Delete() return false } return true }, func(cursor *dstutil.Cursor) bool { return true }) } type DebugInfo struct { FilePath string Line int CheckOldLine bool } func BuildDSTDebugInfo(srcPath string, file *dst.File) (*DebugInfo, error) { result := &DebugInfo{FilePath: srcPath} if file != nil { fset, f, err := decorator.RestoreFile(file) if err != nil { return nil, err } originalFile, err := os.ReadFile(srcPath) if err != nil { return nil, err } line, err := findFirstNoImportLocation(fset, f, bytes.NewBuffer(originalFile)) if err != nil { return nil, err } result.Line = line result.CheckOldLine = true } else { result.Line = 1 result.CheckOldLine = false } return result, nil } func WriteDSTFile(path string, file *dst.File, debug *DebugInfo) error { output, err := os.Create(path) if err != nil { return err } defer output.Close() content, err := GenerateDSTFileContent(file, debug) if err != nil { return err } if _, err = output.WriteString(content); err != nil { return err } return nil } func GenerateDSTFileContent(file *dst.File, debug *DebugInfo) (string, error) { var buf bytes.Buffer writer := io.Writer(&buf) fset, af, err := decorator.RestoreFile(file) if err != nil { return "", err } if debug != nil { if err1 := writeDSTFileWithDebug(fset, af, debug, writer); err1 != nil { return "", err1 } return buf.String(), nil } if err := printer.Fprint(writer, fset, af); err != nil { return "", err } return buf.String(), nil } func BuildFuncIdentity(pkgPath string, node *dst.FuncDecl) string { var receiver string if node.Recv != nil { expr, ok := node.Recv.List[0].Type.(*dst.StarExpr) if !ok { return "" } ident, ok := expr.X.(*dst.Ident) if !ok { return "" } receiver = ident.Name } return fmt.Sprintf("%s_%s%s", regexp.MustCompile(`[/.\-@]`).ReplaceAllString(pkgPath, "_"), receiver, node.Name) } type ImportAnalyzer struct { imports map[string]map[string]*dst.ImportSpec usedImports map[string]*dst.ImportSpec } func CreateImportAnalyzer() *ImportAnalyzer { return &ImportAnalyzer{ imports: make(map[string]map[string]*dst.ImportSpec), usedImports: make(map[string]*dst.ImportSpec)} } func (i *ImportAnalyzer) AnalyzeFileImports(filePath string, f dst.Node) { imports := make(map[string]*dst.ImportSpec) i.imports[filePath] = imports dstutil.Apply(f, func(cursor *dstutil.Cursor) bool { importSpec, ok := cursor.Node().(*dst.ImportSpec) if !ok { return true } var pkgName = filepath.Base(importSpec.Path.Value) if importSpec.Name != nil { pkgName = importSpec.Name.Name } imports[strings.Trim(pkgName, "\"")] = importSpec return false }, func(cursor *dstutil.Cursor) bool { return true }) } func (i *ImportAnalyzer) AnalyzeNeedsImports(filePath string, fields *dst.FieldList) { if fields == nil || len(fields.List) == 0 { return } for _, f := range fields.List { i.analyzeFieldImport(filePath, f.Type) } } func (i *ImportAnalyzer) analyzeFieldImport(filePath string, exp dst.Expr) { switch n := exp.(type) { case *dst.Ident: return case *dst.SelectorExpr: pkgRefName, ok := n.X.(*dst.Ident) if !ok { return } imports := i.imports[filePath] if imports == nil { return } spec := imports[pkgRefName.Name] if spec == nil { return } i.usedImports[pkgRefName.Name] = spec case *dst.Ellipsis: i.analyzeFieldImport(filePath, n.Elt) case *dst.ArrayType: i.analyzeFieldImport(filePath, n.Elt) case *dst.StarExpr: i.analyzeFieldImport(filePath, n.X) } } func (i *ImportAnalyzer) AppendUsedImports(decl *dst.GenDecl) { if decl.Tok != token.IMPORT { return } for _, spec := range i.usedImports { found := false for _, existingSpec := range decl.Specs { if existingSpec.(*dst.ImportSpec).Path.Value == spec.Path.Value { found = true break } } if !found { decl.Specs = append(decl.Specs, dst.Clone(spec).(*dst.ImportSpec)) } } } func writeDSTFileWithDebug(fset *token.FileSet, file *ast.File, debug *DebugInfo, output io.Writer) error { var changeInfo *dstFilePathChangeInfo if !debug.CheckOldLine { changeInfo = &dstFilePathChangeInfo{ oldDebugPath: debug.FilePath, oldDebugLine: 1, newDebugLine: 1, } if _, err := fmt.Fprintf(output, "//line %s:%d\n", debug.FilePath, debug.Line); err != nil { return err } if err := printer.Fprint(output, fset, file); err != nil { return err } return nil } var buffer bytes.Buffer if err := printer.Fprint(&buffer, fset, file); err != nil { return err } newPosition, err := findFirstNoImportLocation(fset, file, bytes.NewBuffer(buffer.Bytes())) if err != nil { return err } changeInfo = &dstFilePathChangeInfo{ oldDebugPath: debug.FilePath, oldDebugLine: debug.Line, newDebugLine: newPosition, } lineCount := 1 alreadyChange := false for { line, err := buffer.ReadBytes('\n') if err != nil { if err == io.EOF && !alreadyChange { return fmt.Errorf("rewrite file line number failure: %v", err) } break } if lineCount == changeInfo.newDebugLine { line = []byte(fmt.Sprintf("//line %s:%d\n%s", debug.FilePath, changeInfo.oldDebugLine, line)) alreadyChange = true } if _, e := output.Write(line); e != nil { return err } lineCount++ } return nil } type dstFilePathChangeInfo struct { oldDebugPath string oldDebugLine int newDebugLine int } func findFirstNoImportLocation(fset *token.FileSet, file *ast.File, fileContent *bytes.Buffer) (int, error) { var pos token.Pos for _, decl := range file.Decls { if genDecl, ok := decl.(*ast.GenDecl); ok { if genDecl.Tok == token.IMPORT { pos = genDecl.End() continue } } break } importEndLine := fset.Position(pos).Line if pos == 0 { if len(file.Decls) == 0 { return 1, nil } importEndLine = fset.Position(file.Decls[0].Pos()).Line } lineNumber := 0 for { line, err := fileContent.ReadBytes('\n') if err != nil { return 0, err } lineNumber++ if lineNumber < importEndLine { continue } trimed := strings.TrimSpace(string(line)) if trimed == "" || trimed == ")" || (strings.HasPrefix(trimed, "import ")) || (packageImportExp.MatchString(trimed)) { continue } return lineNumber, nil } }