tools/go-agent/instrument/runtime/instrument.go (267 lines of code) (raw):

// Licensed to Apache Software Foundation (ASF) under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Apache Software Foundation (ASF) licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package runtime import ( "github.com/dave/dst" "github.com/dave/dst/dstutil" "github.com/apache/skywalking-go/tools/go-agent/instrument/api" "github.com/apache/skywalking-go/tools/go-agent/instrument/consts" "github.com/apache/skywalking-go/tools/go-agent/tools" ) var defaultInternalAtomicPath = "runtime/internal/atomic" type Instrument struct { goIDType string opts *api.CompileOptions } func NewInstrument() *Instrument { return &Instrument{} } func (r *Instrument) CouldHandle(opts *api.CompileOptions) bool { r.opts = opts return opts.Package == "runtime" } func (r *Instrument) FilterAndEdit(path string, curFile *dst.File, cursor *dstutil.Cursor, allFiles []*dst.File) bool { switch n := cursor.Node().(type) { case *dst.TypeSpec: if n.Name != nil && n.Name.Name != "g" { return false } st, ok := n.Type.(*dst.StructType) if !ok { return false } for _, f := range st.Fields.List { if len(f.Names) > 0 && f.Names[0].Name == "goid" { r.goIDType = f.Type.(*dst.Ident).Name } } // append the tls field st.Fields.List = append(st.Fields.List, &dst.Field{ Names: []*dst.Ident{dst.NewIdent(consts.TLSFieldName)}, Type: dst.NewIdent("interface{}")}) tools.LogWithStructEnhance("runtime", "g", consts.TLSFieldName, "tls field") return true case *dst.FuncDecl: if n.Name.Name != "newproc1" { return false } if len(n.Type.Results.List) != 1 { return false } expectedParamCount := 3 if r.opts.CheckGoVersionGreaterOrEqual(1, 23) { expectedParamCount = 5 } if len(n.Type.Params.List) != expectedParamCount { return false } parameters := tools.EnhanceParameterNames(n.Type.Params, tools.FieldListTypeParam) results := tools.EnhanceParameterNames(n.Type.Results, tools.FieldListTypeResult) tools.InsertStmtsBeforeBody(n.Body, `defer func() { {{(index .Results 0).Name}}.{{.TLSField}} = goroutineChange({{(index .Parameters 1).Name}}.{{.TLSField}}) }() `, struct { Parameters []*tools.ParameterInfo Results []*tools.ParameterInfo TLSField string OperatorField string SnapshotInterface string }{ Parameters: parameters, Results: results, TLSField: consts.TLSFieldName, OperatorField: consts.GlobalTracerFieldName, SnapshotInterface: consts.GlobalTracerSnapshotInterface, }) tools.LogWithMethodEnhance("runtime", "", "newproc1", "support cross goroutine context propagating") return true } return false } func (r *Instrument) AfterEnhanceFile(fromPath, newPath string) error { return nil } func (r *Instrument) parseInternalAtomicPath() string { if r.opts.CheckGoVersionGreaterOrEqual(1, 23) { return "internal/runtime/atomic" } return defaultInternalAtomicPath } // nolint func (r *Instrument) WriteExtraFiles(dir string) ([]string, error) { return tools.WriteMultipleFile(dir, map[string]string{ "skywalking_tls_operator.go": tools.ExecuteTemplate(`package runtime import ( _ "unsafe" atomic "{{.InternalAtomicPath}}" ) var {{.GlobalTracerFieldName}} interface{} var {{.GlobalLoggerFieldName}} interface{} var {{.GlobalTracerInitNotifyFieldName}} = make([]func(), 0) var _metricsRegisterLockerVal int32 = 0 var _metricsRegisterLocker = &_metricsRegisterLockerVal var {{.MetricsRegisterFieldName}} = make([]interface{}, 0) var {{.MetricsHookFieldName}} = make([]func(), 0) //go:linkname {{.TLSGetMethod}} {{.TLSGetMethod}} var {{.TLSGetMethod}} = _skywalking_tls_get_impl //go:linkname {{.TLSSetMethod}} {{.TLSSetMethod}} var {{.TLSSetMethod}} = _skywalking_tls_set_impl //go:linkname {{.GlobalOperatorSetMethodName}} {{.GlobalOperatorSetMethodName}} var {{.GlobalOperatorSetMethodName}} = _skywalking_global_operator_set_impl //go:linkname {{.GlobalOperatorGetMethodName}} {{.GlobalOperatorGetMethodName}} var {{.GlobalOperatorGetMethodName}} = _skywalking_global_operator_get_impl //go:linkname {{.GlobalLoggerSetMethodName}} {{.GlobalLoggerSetMethodName}} var {{.GlobalLoggerSetMethodName}} = _skywalking_global_logger_set_impl //go:linkname {{.GlobalLoggerGetMethodName}} {{.GlobalLoggerGetMethodName}} var {{.GlobalLoggerGetMethodName}} = _skywalking_global_logger_get_impl //go:linkname {{.GoroutineIDGetterMethodName}} {{.GoroutineIDGetterMethodName}} var {{.GoroutineIDGetterMethodName}} = _skywalking_get_goid_impl //go:linkname {{.GlobalTracerInitNotifyMethodName}} {{.GlobalTracerInitNotifyMethodName}} var {{.GlobalTracerInitNotifyMethodName}} = _skywalking_global_tracer_init_notify_impl //go:linkname {{.GlobalTracerInitNotifyGetMethodName}} {{.GlobalTracerInitNotifyGetMethodName}} var {{.GlobalTracerInitNotifyGetMethodName}} = _skywalking_global_tracer_init_get_notify_impl //go:linkname {{.MetricsRegisterAppendMethodName}} {{.MetricsRegisterAppendMethodName}} var {{.MetricsRegisterAppendMethodName}} = _skywalking_metrics_register_append_impl //go:linkname {{.MetricsObtainMethodName}} {{.MetricsObtainMethodName}} var {{.MetricsObtainMethodName}} = _skywalking_metrics_obtain_impl //go:nosplit func _skywalking_get_goid_impl() int64 { return {{.GoroutineIDCaster}} } //go:nosplit func _skywalking_tls_get_impl() interface{} { return getg().m.curg.{{.TLSFiledName}} } //go:nosplit func _skywalking_tls_set_impl(v interface{}) { getg().m.curg.{{.TLSFiledName}} = v } //go:nosplit func _skywalking_global_operator_set_impl(v interface{}) { {{.GlobalTracerFieldName}} = v } //go:nosplit func _skywalking_global_operator_get_impl() interface{} { return {{.GlobalTracerFieldName}} } //go:nosplit func _skywalking_global_logger_set_impl(v interface{}) { {{.GlobalLoggerFieldName}} = v } //go:nosplit func _skywalking_global_logger_get_impl() interface{} { return {{.GlobalLoggerFieldName}} } //go:nosplit func _skywalking_global_tracer_init_notify_impl(fun func()) { {{.GlobalTracerInitNotifyFieldName}} = append({{.GlobalTracerInitNotifyFieldName}}, fun) } //go:nosplit func _skywalking_global_tracer_init_get_notify_impl() []func() { return {{.GlobalTracerInitNotifyFieldName}} } //go:nosplit func _skywalking_metrics_register_append_impl(v interface{}) { for { tmp := atomic.Loadint32(_metricsRegisterLocker) if atomic.Casint32(_metricsRegisterLocker, tmp, tmp+1) { {{.MetricsRegisterFieldName}} = append({{.MetricsRegisterFieldName}}, v) break } } } //go:nosplit func _skywalking_metrics_obtain_impl() ([]interface{}, []func()) { for { tmp := atomic.Loadint32(_metricsRegisterLocker) if tmp == 0 { return nil, nil } if atomic.Casint32(_metricsRegisterLocker, tmp, 0) { registers := {{.MetricsRegisterFieldName}} {{.MetricsRegisterFieldName}} = make([]interface{}, 0) hooks := {{.MetricsHookFieldName}} {{.MetricsHookFieldName}} = make([]func(), 0) return registers, hooks } } } //go:nosplit func _skywalking_metrics_hook_append_impl(f func()) { for { tmp := atomic.Loadint32(_metricsRegisterLocker) if atomic.Casint32(_metricsRegisterLocker, tmp, tmp+1) { {{.MetricsHookFieldName}} = append({{.MetricsHookFieldName}}, f) break } } } type ContextSnapshoter interface { TakeSnapShot(val interface{}) interface{} } func goroutineChange(tls interface{}) interface{} { if tls == nil { return nil } if taker, ok := tls.(ContextSnapshoter); ok { return taker.TakeSnapShot(tls) } return tls } `, struct { TLSFiledName string TLSGetMethod string TLSSetMethod string GlobalTracerFieldName string GlobalTracerSnapshotInterface string GlobalOperatorSetMethodName string GlobalOperatorGetMethodName string GlobalLoggerFieldName string GlobalLoggerSetMethodName string GlobalLoggerGetMethodName string GoroutineIDGetterMethodName string GoroutineIDCaster string GlobalTracerInitNotifyFieldName string GlobalTracerInitNotifyMethodName string GlobalTracerInitNotifyGetMethodName string MetricsRegisterFieldName string MetricsRegisterAppendMethodName string MetricsObtainMethodName string MetricsHookFieldName string MetricsHookAppendMethodName string InternalAtomicPath string }{ TLSFiledName: consts.TLSFieldName, TLSGetMethod: consts.TLSGetMethodName, TLSSetMethod: consts.TLSSetMethodName, GlobalTracerFieldName: consts.GlobalTracerFieldName, GlobalTracerSnapshotInterface: consts.GlobalTracerSnapshotInterface, GlobalOperatorSetMethodName: consts.GlobalTracerSetMethodName, GlobalOperatorGetMethodName: consts.GlobalTracerGetMethodName, GlobalLoggerFieldName: consts.GlobalLoggerFieldName, GlobalLoggerSetMethodName: consts.GlobalLoggerSetMethodName, GlobalLoggerGetMethodName: consts.GlobalLoggerGetMethodName, GoroutineIDGetterMethodName: consts.CurrentGoroutineIDGetMethodName, GoroutineIDCaster: r.generateCastGoID("getg().m.curg.goid"), GlobalTracerInitNotifyFieldName: consts.GlobalTracerInitNotifyFieldName, GlobalTracerInitNotifyMethodName: consts.GlobalTracerInitAppendNotifyMethodName, GlobalTracerInitNotifyGetMethodName: consts.GlobalTracerInitGetNotifyMethodName, MetricsRegisterFieldName: consts.MetricsRegisterFieldName, MetricsRegisterAppendMethodName: consts.MetricsRegisterAppendMethodName, MetricsObtainMethodName: consts.MetricsObtainMethodName, MetricsHookFieldName: consts.MetricsHookFieldName, MetricsHookAppendMethodName: consts.MetricsHookAppendMethodName, InternalAtomicPath: r.parseInternalAtomicPath(), }), }) } func (r *Instrument) generateCastGoID(val string) string { switch r.goIDType { case "int64": return val case "uint64": case "int32": case "uint32": case "int": case "uint": default: panic("cannot find goid type in the g struct or the type is not supported: " + r.goIDType) } return "int64(" + val + ")" }