tools/go-agent/cmd/injector.go (287 lines of code) (raw):
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package main
import (
"fmt"
"go/parser"
"go/token"
"io/fs"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"github.com/apache/skywalking-go/tools/go-agent/tools"
"github.com/dave/dst"
"github.com/dave/dst/decorator"
)
const (
projectBaseImportPath = "github.com/apache/skywalking-go"
goModFileName = "go.mod"
swImportFileName = "skywalking_inject.go"
)
var (
swImportFileContent = fmt.Sprintf(`// Code generated by skywalking-go-agent. DO NOT EDIT.
package main
import _ "%s"`, projectBaseImportPath)
gitSHARegex = regexp.MustCompile(`^[0-9a-fA-F]{40}$|^[0-9a-fA-F]{7}$`)
)
type projectInjector struct {
}
func InjectProject(flags *EnhancementToolFlags) error {
stat, err := os.Stat(flags.Inject)
if err != nil {
return err
}
if version == "" {
return fmt.Errorf("version is empty, please use the release version of skywalking-go")
}
abs, err := filepath.Abs(flags.Inject)
if err != nil {
return err
}
injector := &projectInjector{}
if stat.IsDir() {
return injector.injectDir(abs, flags.AllProjects)
}
return injector.injectFile(abs)
}
func (i *projectInjector) injectDir(path string, allProjects bool) error {
if !i.findGoModFileInDir(path) {
return fmt.Errorf("cannot fing go.mod file in %s, plase make sure that your inject path is a project directory", path)
}
// find all projects and main directory
projects, err := i.findProjects(path, allProjects)
if err != nil {
return err
}
// filter validated projects
validatedProjects := make([]*projectWithMainDirectory, 0)
for _, project := range projects {
if project.isValid() {
validatedProjects = append(validatedProjects, project)
}
}
fmt.Printf("total %d validate projects found\n", len(validatedProjects))
// inject library
for _, project := range validatedProjects {
if err := i.injectLibraryInRoot(project.ProjectPath); err != nil {
return err
}
for _, mainDir := range project.MainPackageDirs {
contains, err := i.alreadyContainsLibraryImport(mainDir)
if err != nil {
return err
}
if contains {
fmt.Printf("main package %s already contains imports, skip\n", mainDir)
continue
}
// append a new file to the main package
if err := i.appendNewImportFile(mainDir); err != nil {
return fmt.Errorf("append new import file failed in %s, %v", mainDir, err)
}
fmt.Printf("append new import file success in %s\n", mainDir)
}
}
return nil
}
func (i *projectInjector) injectFile(path string) error {
if !strings.HasSuffix(path, ".go") {
return fmt.Errorf("only support inject go file, %s is not a go file", path)
}
dir := filepath.Dir(path)
if !i.findGoModFileInDir(dir) {
return fmt.Errorf("cannot fing go.mod file in %s", dir)
}
// inject library
if err := i.injectLibraryInRoot(dir); err != nil {
return err
}
// if only inject to a file, then just add import into the file
return i.injectImportInFile(path)
}
func (i *projectInjector) findGoModFileInDir(dir string) bool {
path := filepath.Join(dir, goModFileName)
stat, err := os.Stat(path)
if err != nil {
return false
}
return stat != nil
}
func (i *projectInjector) injectLibraryInRoot(dir string) error {
v := version
if !gitSHARegex.MatchString(version) {
v = "v" + version
}
fmt.Printf("injecting skywalking-go@%s depenedency into %s\n", v, dir)
command := exec.Command("go", "get", "github.com/apache/skywalking-go@"+v)
command.Dir = dir
command.Stdin = os.Stdin
command.Stdout = os.Stdout
command.Stderr = os.Stderr
err := command.Run()
if err != nil {
return err
}
return nil
}
func (i *projectInjector) injectImportInFile(path string) error {
filename := filepath.Base(path)
content, err := os.ReadFile(path)
if err != nil {
return err
}
f, err := decorator.ParseFile(nil, filename, content, parser.ParseComments)
if err != nil {
return fmt.Errorf("parse file %s failed, %v", path, err)
}
if i.addingProjectImportInFileAndRewrite(f) {
fmt.Printf("already existing library import in %s, skip\n", path)
}
fileContent, err := tools.GenerateDSTFileContent(f, nil)
if err != nil {
return fmt.Errorf("generate file content failed, %v", err)
}
err = os.WriteFile(path, []byte(fileContent), 0o600)
if err != nil {
return fmt.Errorf("rewrite the file %s failed, %v", path, err)
}
fmt.Printf("adding skywalking-go import into the file: %s", path)
return nil
}
func (i *projectInjector) addingProjectImportInFileAndRewrite(f *dst.File) bool {
var latestImportDel *dst.GenDecl
var existingImport bool
for _, decl := range f.Decls {
if gen, ok := decl.(*dst.GenDecl); ok && gen != nil && gen.Tok == token.IMPORT {
latestImportDel = gen
if !existingImport && i.containsImport(gen) {
existingImport = true
}
}
}
if existingImport {
return true
}
if latestImportDel == nil {
latestImportDel = &dst.GenDecl{
Tok: token.IMPORT,
Specs: []dst.Spec{},
}
f.Decls = append([]dst.Decl{latestImportDel}, f.Decls...)
}
latestImportDel.Specs = append(latestImportDel.Specs, &dst.ImportSpec{
Name: dst.NewIdent("_"),
Path: &dst.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf("%q", projectBaseImportPath),
},
})
return false
}
func (i *projectInjector) findProjects(currentDir string, all bool) ([]*projectWithMainDirectory, error) {
result := make([]*projectWithMainDirectory, 0)
stack := make([]*projectWithMainDirectory, 0)
currentStackPrefix := ""
err := filepath.WalkDir(currentDir, func(path string, d fs.DirEntry, err error) error {
if !d.IsDir() {
return nil
}
if strings.HasPrefix(filepath.Base(path), ".") {
return filepath.SkipDir
}
if currentStackPrefix != "" && !strings.HasPrefix(path, currentStackPrefix) {
stack = stack[:len(stack)-1]
currentStackPrefix = stack[len(stack)-1].ProjectPath
}
if f, e := os.Stat(filepath.Join(path, goModFileName)); e == nil && f != nil {
if len(stack) > 0 && !all {
return filepath.SkipDir
}
info := &projectWithMainDirectory{
ProjectPath: path,
}
result = append(result, info)
stack = append(stack, info)
currentStackPrefix = path
}
if mainPackage, e := i.containsMainPackageInCurrentDirectory(path); e != nil {
return err
} else if mainPackage {
currentModule := stack[len(stack)-1]
currentModule.MainPackageDirs = append(currentModule.MainPackageDirs, path)
}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}
func (i *projectInjector) containsMainPackageInCurrentDirectory(dir string) (bool, error) {
readDir, err := os.ReadDir(dir)
if err != nil {
return false, fmt.Errorf("read dir %s failed, %v", dir, err)
}
for _, file := range readDir {
if file.IsDir() || !strings.HasSuffix(file.Name(), ".go") {
continue
}
parseFile, err := parser.ParseFile(token.NewFileSet(), filepath.Join(dir, file.Name()), nil, parser.PackageClauseOnly)
if err != nil {
return false, err
}
if parseFile.Name.Name == "main" {
return true, nil
}
// only needs to check the first .go file, other files should be same
return false, nil
}
return false, nil
}
func (i *projectInjector) alreadyContainsLibraryImport(dir string) (bool, error) {
readDir, err := os.ReadDir(dir)
if err != nil {
return false, fmt.Errorf("reding directory %s failure, %v", dir, err)
}
for _, f := range readDir {
if f.IsDir() {
continue
}
if !strings.HasSuffix(f.Name(), ".go") {
continue
}
file, err := os.ReadFile(filepath.Join(dir, f.Name()))
if err != nil {
return false, fmt.Errorf("read file %s failed, %v", f.Name(), err)
}
dstFile, err := decorator.ParseFile(nil, f.Name(), file, parser.ImportsOnly)
if err != nil {
return false, fmt.Errorf("parsing file %s failed, %v", f.Name(), err)
}
var existingImport = false
for _, decl := range dstFile.Decls {
if gen, ok := decl.(*dst.GenDecl); ok && gen != nil && gen.Tok == token.IMPORT &&
!existingImport && i.containsImport(gen) {
existingImport = true
}
}
if existingImport {
return true, nil
}
}
return false, nil
}
func (i *projectInjector) containsImport(imp *dst.GenDecl) bool {
for _, spec := range imp.Specs {
if i, ok := spec.(*dst.ImportSpec); !ok || i == nil {
continue
} else if i.Path != nil && i.Path.Value == fmt.Sprintf("%q", projectBaseImportPath) {
return true
}
}
return false
}
func (i *projectInjector) appendNewImportFile(dir string) error {
importFilePath := filepath.Join(dir, swImportFileName)
return os.WriteFile(importFilePath, []byte(swImportFileContent), 0o600)
}
type projectWithMainDirectory struct {
ProjectPath string
MainPackageDirs []string
}
func (p *projectWithMainDirectory) isValid() bool {
return p.ProjectPath != "" && len(p.MainPackageDirs) > 0
}