grpc/server/ctxlogger/ctxlogger.go (97 lines of code) (raw):

package ctxlogger import ( "context" "encoding/json" log "log/slog" loggable "buf.build/gen/go/service-hub/loggable/protocolbuffers/go/proto" "github.com/Azure/aks-middleware/grpc/common" "google.golang.org/grpc" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" ) // ExtractFunction extracts information from the ctx and/or the request and put it in the logger. // This function is called before the application's handler is called so that it can add more context // to the logger. type ExtractFunction func(ctx context.Context, req any, info *grpc.UnaryServerInfo, logger *log.Logger) *log.Logger type loggerKeyType int const ( loggerKey loggerKeyType = iota ) func WithLogger(ctx context.Context, logger *log.Logger) context.Context { return context.WithValue(ctx, loggerKey, logger) } func GetLogger(ctx context.Context) *log.Logger { logger := log.Default().With("src", "self gen, not available in ctx") if ctx == nil { return logger } if ctxlogger, ok := ctx.Value(loggerKey).(*log.Logger); ok { return ctxlogger } return logger } // UnaryServerInterceptor returns a UnaryServerInterceptor. // extractFunction can be nil if the defaultExtractFunction() is good enough. // extractFunction is for ctx or request specific information. // For information that doesn't change with ctx/request, pass the information via logger. // The first registerred interceptor will be called first. // Need to register requestid first to add request-id. // Then the logger can get the request-id. func UnaryServerInterceptor(logger *log.Logger, extractFunction ExtractFunction) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { l := logger if extractFunction != nil { l = extractFunction(ctx, req, info, l) } else { l = defaultExtractFunction(ctx, req, info, l) } l = l.With(requestContentLogKey, FilterLogs(req)) ctx = WithLogger(ctx, l) // log.Print("logger ctx: ", ctx) return handler(ctx, req) } } const ( methodLogKey = "method" requestContentLogKey = "request" ) func defaultExtractFunction(ctx context.Context, req any, info *grpc.UnaryServerInfo, logger *log.Logger) *log.Logger { l := logger l = l.With(methodLogKey, info.FullMethod) l = l.With(common.GetFields(ctx)...) return l } func filterLoggableFields(currentMap map[string]interface{}, message protoreflect.Message) map[string]interface{} { // Check if the map or the message is nil if currentMap == nil || message == nil { return currentMap } for name, value := range currentMap { // Get the field descriptor by name fd := message.Descriptor().Fields().ByName(protoreflect.Name(name)) // Check if the field descriptor is nil if fd == nil { continue } opts := fd.Options() fdOpts := opts.(*descriptorpb.FieldOptions) loggable := proto.GetExtension(fdOpts, loggable.E_Loggable) // Delete the field from the map if it is not loggable if !loggable.(bool) { delete(currentMap, name) continue } // Check if the value is another map[string]interface{} if subMap, ok := value.(map[string]interface{}); ok { // Check if its a simple map or one containing messages if fd.Message() != nil && !fd.Message().IsMapEntry() { // Get the sub-message for the field subMessage := message.Get(fd).Message() // Call the helper function recursively on the subMap and subMessage currentMap[name] = filterLoggableFields(subMap, subMessage) } } } return currentMap } func FilterLogs(req any) map[string]interface{} { in, ok := req.(proto.Message) var reqPayload map[string]interface{} if ok { // Get the protoreflect.Message interface for the message message := in.ProtoReflect() // Marshal the message to JSON bytes jsonBytes, err := protojson.Marshal(message.Interface().(protoreflect.ProtoMessage)) if err != nil { log.Error(err.Error()) } // Unmarshal the JSON bytes to a map[string]interface{} err = json.Unmarshal(jsonBytes, &reqPayload) if err != nil { log.Error(err.Error()) } // Filter out the fields that are not loggable using the helper function reqPayload = filterLoggableFields(reqPayload, message) } return reqPayload }