correlation/grpc/server_interceptors.go (62 lines of code) (raw):

package grpccorrelation import ( "context" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "gitlab.com/gitlab-org/labkit/correlation" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) func extractFromContext(ctx context.Context, propagateIncomingCorrelationID bool) (context.Context, string) { var correlationID string md, ok := metadata.FromIncomingContext(ctx) if ok { if propagateIncomingCorrelationID { // Extract correlation_id correlationID = CorrelationIDFromMetadata(md) } // Extract client name clientNames := md.Get(metadataClientNameKey) if len(clientNames) > 0 { ctx = correlation.ContextWithClientName(ctx, clientNames[0]) } } if correlationID == "" { correlationID = correlation.SafeRandomID() } ctx = correlation.ContextWithCorrelation(ctx, correlationID) return ctx, correlationID } // CorrelationIDFromMetadata can be used to extract correlation ID from request/response metadata. // Returns an empty string if correlation ID is not found. func CorrelationIDFromMetadata(md metadata.MD) string { values := md.Get(metadataCorrelatorKey) if len(values) > 0 { return values[0] } return "" } // UnaryServerCorrelationInterceptor propagates Correlation-IDs from incoming upstream services. func UnaryServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.UnaryServerInterceptor { config := applyServerCorrelationInterceptorOptions(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { ctx, correlationID := extractFromContext(ctx, config.propagateIncomingCorrelationID) if config.reversePropagateCorrelationID { sts := grpc.ServerTransportStreamFromContext(ctx) err := sts.SetHeader(metadata.Pairs(metadataCorrelatorKey, correlationID)) if err != nil { return nil, err } } return handler(ctx, req) } } // StreamServerCorrelationInterceptor propagates Correlation-IDs from incoming upstream services. func StreamServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.StreamServerInterceptor { config := applyServerCorrelationInterceptorOptions(opts) return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { var correlationID string wrapped := grpc_middleware.WrapServerStream(ss) wrapped.WrappedContext, correlationID = extractFromContext(ss.Context(), config.propagateIncomingCorrelationID) if config.reversePropagateCorrelationID { err := wrapped.SetHeader(metadata.Pairs(metadataCorrelatorKey, correlationID)) if err != nil { return err } } return handler(srv, wrapped) } }