tool/util/ast.go (417 lines of code) (raw):
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import (
"fmt"
"go/parser"
"go/token"
"os"
"path/filepath"
"regexp"
"github.com/alibaba/opentelemetry-go-auto-instrumentation/tool/errc"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
)
const (
IdentNil = "nil"
IdentTrue = "true"
IdentFalse = "false"
IdentIgnore = "_"
)
// AST Construction
func AddressOf(expr dst.Expr) *dst.UnaryExpr {
return &dst.UnaryExpr{Op: token.AND, X: dst.Clone(expr).(dst.Expr)}
}
func CallTo(name string, args []dst.Expr) *dst.CallExpr {
return &dst.CallExpr{
Fun: &dst.Ident{Name: name},
Args: args,
}
}
func MakeUnusedIdent(ident *dst.Ident) *dst.Ident {
ident.Name = IdentIgnore
return ident
}
func IsUnusedIdent(ident *dst.Ident) bool {
return ident.Name == IdentIgnore
}
func Ident(name string) *dst.Ident {
return &dst.Ident{
Name: name,
}
}
func StringLit(value string) *dst.BasicLit {
return &dst.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf("%q", value),
}
}
func IsStringLit(expr dst.Expr, val string) bool {
lit, ok := expr.(*dst.BasicLit)
return ok &&
lit.Kind == token.STRING &&
lit.Value == fmt.Sprintf("%q", val)
}
func IntLit(value int) *dst.BasicLit {
return &dst.BasicLit{
Kind: token.INT,
Value: fmt.Sprintf("%d", value),
}
}
func Block(stmt dst.Stmt) *dst.BlockStmt {
return &dst.BlockStmt{
List: []dst.Stmt{
stmt,
},
}
}
func BlockStmts(stmts ...dst.Stmt) *dst.BlockStmt {
return &dst.BlockStmt{
List: stmts,
}
}
func Exprs(exprs ...dst.Expr) []dst.Expr {
return exprs
}
func Stmts(stmts ...dst.Stmt) []dst.Stmt {
return stmts
}
func SelectorExpr(x dst.Expr, sel string) *dst.SelectorExpr {
return &dst.SelectorExpr{
X: dst.Clone(x).(dst.Expr),
Sel: Ident(sel),
}
}
func IndexExpr(x dst.Expr, index dst.Expr) *dst.IndexExpr {
return &dst.IndexExpr{
X: dst.Clone(x).(dst.Expr),
Index: dst.Clone(index).(dst.Expr),
}
}
func TypeAssertExpr(x dst.Expr, typ dst.Expr) *dst.TypeAssertExpr {
return &dst.TypeAssertExpr{
X: x,
Type: dst.Clone(typ).(dst.Expr),
}
}
func ParenExpr(x dst.Expr) *dst.ParenExpr {
return &dst.ParenExpr{
X: dst.Clone(x).(dst.Expr),
}
}
func NewField(name string, typ dst.Expr) *dst.Field {
newField := &dst.Field{
Names: []*dst.Ident{dst.NewIdent(name)},
Type: typ,
}
return newField
}
func BoolTrue() *dst.BasicLit {
return &dst.BasicLit{Value: IdentTrue}
}
func BoolFalse() *dst.BasicLit {
return &dst.BasicLit{Value: IdentFalse}
}
func IsInterfaceType(typ dst.Expr) bool {
_, ok := typ.(*dst.InterfaceType)
return ok
}
func IsEllipsis(typ dst.Expr) bool {
_, ok := typ.(*dst.Ellipsis)
return ok
}
func InterfaceType() *dst.InterfaceType {
return &dst.InterfaceType{Methods: &dst.FieldList{List: nil}}
}
func ArrayType(elem dst.Expr) *dst.ArrayType {
return &dst.ArrayType{Elt: elem}
}
func IfStmt(init dst.Stmt, cond dst.Expr,
body, elseBody *dst.BlockStmt) *dst.IfStmt {
return &dst.IfStmt{
Init: dst.Clone(init).(dst.Stmt),
Cond: dst.Clone(cond).(dst.Expr),
Body: dst.Clone(body).(*dst.BlockStmt),
Else: dst.Clone(elseBody).(*dst.BlockStmt),
}
}
func IfNotNilStmt(cond dst.Expr, body, elseBody *dst.BlockStmt) *dst.IfStmt {
var elseB dst.Stmt
if elseBody == nil {
elseB = nil
} else {
elseB = dst.Clone(elseBody).(dst.Stmt)
}
return &dst.IfStmt{
Cond: &dst.BinaryExpr{
X: dst.Clone(cond).(dst.Expr),
Op: token.NEQ,
Y: &dst.Ident{Name: IdentNil},
},
Body: dst.Clone(body).(*dst.BlockStmt),
Else: elseB,
}
}
func EmptyStmt() *dst.EmptyStmt {
return &dst.EmptyStmt{}
}
func ExprStmt(expr dst.Expr) *dst.ExprStmt {
return &dst.ExprStmt{X: dst.Clone(expr).(dst.Expr)}
}
func DeferStmt(call *dst.CallExpr) *dst.DeferStmt {
return &dst.DeferStmt{Call: dst.Clone(call).(*dst.CallExpr)}
}
func ReturnStmt(results []dst.Expr) *dst.ReturnStmt {
return &dst.ReturnStmt{Results: results}
}
func AssignStmt(lhs, rhs dst.Expr) *dst.AssignStmt {
return &dst.AssignStmt{
Lhs: []dst.Expr{lhs},
Tok: token.ASSIGN,
Rhs: []dst.Expr{rhs},
}
}
func DefineStmts(lhs, rhs []dst.Expr) *dst.AssignStmt {
return &dst.AssignStmt{
Lhs: lhs,
Tok: token.DEFINE,
Rhs: rhs,
}
}
func SwitchCase(list []dst.Expr, stmts []dst.Stmt) *dst.CaseClause {
return &dst.CaseClause{
List: list,
Body: stmts,
}
}
func AddStructField(decl dst.Decl, name string, typ string) {
gen, ok := decl.(*dst.GenDecl)
if !ok {
LogFatal("decl is not a GenDecl")
}
fd := NewField(name, Ident(typ))
st := gen.Specs[0].(*dst.TypeSpec).Type.(*dst.StructType)
st.Fields.List = append(st.Fields.List, fd)
}
func addImport(root *dst.File, paths ...string) *dst.GenDecl {
importStmt := &dst.GenDecl{Tok: token.IMPORT}
specs := make([]dst.Spec, 0)
for _, path := range paths {
spec := &dst.ImportSpec{
Path: &dst.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf("%q", path),
},
Name: &dst.Ident{Name: IdentIgnore},
}
specs = append(specs, spec)
}
importStmt.Specs = specs
root.Decls = append([]dst.Decl{importStmt}, root.Decls...)
return importStmt
}
func AddImportForcely(root *dst.File, paths ...string) *dst.GenDecl {
return addImport(root, paths...)
}
func RemoveImport(root *dst.File, path string) *dst.ImportSpec {
for j, decl := range root.Decls {
if genDecl, ok := decl.(*dst.GenDecl); ok &&
genDecl.Tok == token.IMPORT {
for i, spec := range genDecl.Specs {
if importSpec, ok := spec.(*dst.ImportSpec); ok {
if importSpec.Path.Value == fmt.Sprintf("%q", path) {
genDecl.Specs =
append(genDecl.Specs[:i],
genDecl.Specs[i+1:]...)
if len(genDecl.Specs) == 0 {
root.Decls =
append(root.Decls[:j], root.Decls[j+1:]...)
}
return importSpec
}
}
}
}
}
return nil
}
func FindImport(root *dst.File, path string) *dst.ImportSpec {
for _, decl := range root.Decls {
if genDecl, ok := decl.(*dst.GenDecl); ok &&
genDecl.Tok == token.IMPORT {
for _, spec := range genDecl.Specs {
if importSpec, ok := spec.(*dst.ImportSpec); ok {
if importSpec.Path.Value == fmt.Sprintf("%q", path) {
return importSpec
}
}
}
}
}
return nil
}
func NewVarDecl(name string, paramTypes *dst.FieldList) *dst.GenDecl {
return &dst.GenDecl{
Tok: token.VAR,
Specs: []dst.Spec{
&dst.ValueSpec{
Names: []*dst.Ident{
{Name: name},
},
Type: &dst.FuncType{
Func: false,
Params: paramTypes,
},
},
},
}
}
func DereferenceOf(expr dst.Expr) dst.Expr {
return &dst.StarExpr{X: expr}
}
func HasReceiver(fn *dst.FuncDecl) bool {
return fn.Recv != nil && len(fn.Recv.List) > 0
}
// AST utilities
func FindFuncDecl(root *dst.File, name string) *dst.FuncDecl {
for _, decl := range root.Decls {
if fn, ok := decl.(*dst.FuncDecl); ok && fn.Name.Name == name {
return fn
}
}
return nil
}
func isValidRegex(pattern string) bool {
_, err := regexp.Compile(pattern)
return err == nil
}
func MatchFuncDecl(decl dst.Decl, function string, receiverType string) bool {
Assert(isValidRegex(function), "invalid function name pattern")
funcDecl, ok := decl.(*dst.FuncDecl)
if !ok {
return false
}
re := regexp.MustCompile("^" + function + "$") // strict match
if !re.MatchString(funcDecl.Name.Name) {
return false
}
if receiverType != "" {
re = regexp.MustCompile("^" + receiverType + "$") // strict match
if !HasReceiver(funcDecl) {
return re.MatchString("")
}
switch recvTypeExpr := funcDecl.Recv.List[0].Type.(type) {
case *dst.StarExpr:
if _, ok := recvTypeExpr.X.(*dst.Ident); !ok {
// This is a generic type, we don't support it yet
return false
}
t := "*" + recvTypeExpr.X.(*dst.Ident).Name
return re.MatchString(t)
case *dst.Ident:
t := recvTypeExpr.Name
return re.MatchString(t)
case *dst.IndexExpr:
// This is a generic type, we don't support it yet
return false
default:
msg := fmt.Sprintf("unexpected receiver type: %T", recvTypeExpr)
UnimplementedT(msg)
}
} else {
if HasReceiver(funcDecl) {
return false
}
}
return true
}
func MatchStructDecl(decl dst.Decl, structType string) bool {
if genDecl, ok := decl.(*dst.GenDecl); ok {
if genDecl.Tok == token.TYPE {
if typeSpec, ok := genDecl.Specs[0].(*dst.TypeSpec); ok {
if typeSpec.Name.Name == structType {
return true
}
}
}
}
return false
}
// AST Parser
// @@ N.B. DST framework provides a series of RestoreResolvers such
// as guess.New for resolving the package name from an importPath.
// However, its strategy is simply to guess by taking last section
// of the importpath as the package name. This can lead to issues
// where package names like github.com/foo/v2 are resolved as v2,
// while in reality, they might be foo. Incorrect resolutions can
// lead to some imports that should be present being rudely removed.
// To solve this issue, we disable DST's automatic Import management
// and use plain AST manipulation to add imports.
type AstParser struct {
fset *token.FileSet
dec *decorator.Decorator
}
func NewAstParser() *AstParser {
return &AstParser{
fset: token.NewFileSet(),
}
}
func (ap *AstParser) FindPosition(node dst.Node) token.Position {
astNode := ap.dec.Ast.Nodes[node]
if astNode == nil {
return token.Position{Filename: "", Line: -1, Column: -1} // Invalid
}
return ap.fset.Position(astNode.Pos())
}
// ParseSnippet parses the AST from incomplete source code snippet.
func (ap *AstParser) ParseSnippet(codeSnippnet string) ([]dst.Stmt, error) {
Assert(codeSnippnet != "", "empty code snippet")
snippet := "package main; func _() {" + codeSnippnet + "}"
file, err := decorator.ParseFile(ap.fset, "", snippet, 0)
if err != nil {
return nil, errc.New(errc.ErrParseCode, err.Error())
}
return file.Decls[0].(*dst.FuncDecl).Body.List, nil
}
// ParseSource parses the AST from complete source code.
func (ap *AstParser) ParseSource(source string) (*dst.File, error) {
Assert(source != "", "empty source")
ap.dec = decorator.NewDecorator(ap.fset)
dstRoot, err := ap.dec.Parse(source)
if err != nil {
return nil, errc.New(errc.ErrParseCode, err.Error())
}
return dstRoot, nil
}
func (ap *AstParser) ParseFile(filePath string, mode parser.Mode) (*dst.File, error) {
name := filepath.Base(filePath)
file, err := os.Open(filePath)
if err != nil {
return nil, errc.New(errc.ErrOpenFile, err.Error())
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
LogFatal("failed to close file %s: %v", file.Name(), err)
}
}(file)
astFile, err := parser.ParseFile(ap.fset, name, file, mode)
if err != nil {
return nil, errc.New(errc.ErrParseCode, err.Error())
}
ap.dec = decorator.NewDecorator(ap.fset)
dstFile, err := ap.dec.DecorateFile(astFile)
if err != nil {
return nil, errc.New(errc.ErrParseCode, err.Error())
}
return dstFile, nil
}
func ParseAstFromFileOnlyPackage(filePath string) (*dst.File, error) {
return NewAstParser().ParseFile(filePath, parser.PackageClauseOnly)
}
func ParseAstFromFileFast(filePath string) (*dst.File, error) {
return NewAstParser().ParseFile(filePath, parser.SkipObjectResolution)
}
// ParseAstFromFile parses the AST from complete source file.
func ParseAstFromFile(filePath string) (*dst.File, error) {
return NewAstParser().ParseFile(filePath, parser.ParseComments)
}
// WriteAstToFile writes the AST to source file.
func WriteAstToFile(astRoot *dst.File, filePath string) (string, error) {
file, err := os.Create(filePath)
if err != nil {
return "", errc.New(errc.ErrCreateFile, err.Error())
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
LogFatal("failed to close file %s: %v", file.Name(), err)
}
}(file)
r := decorator.NewRestorer()
err = r.Fprint(file, astRoot)
if err != nil {
return "", errc.New(errc.ErrParseCode, err.Error())
}
return file.Name(), nil
}