correlation/grpc/server_interceptors.go (62 lines of code) (raw):
package grpccorrelation
import (
"context"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"gitlab.com/gitlab-org/labkit/correlation"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// extractFromContext derives the correlation ID for an incoming call and
// returns it together with a context enriched with that ID.
//
// When ctx carries incoming gRPC metadata:
//   - the correlation ID is read from it via CorrelationIDFromMetadata, but
//     only if propagateIncomingCorrelationID is true;
//   - the first metadataClientNameKey value, if any, is attached to the
//     context through correlation.ContextWithClientName (regardless of
//     propagateIncomingCorrelationID).
//
// If no correlation ID was obtained (or it was empty), a replacement is taken
// from correlation.SafeRandomID. The resulting ID is always attached to the
// returned context through correlation.ContextWithCorrelation.
func extractFromContext(ctx context.Context, propagateIncomingCorrelationID bool) (context.Context, string) {
	id := ""

	if incoming, found := metadata.FromIncomingContext(ctx); found {
		if propagateIncomingCorrelationID {
			id = CorrelationIDFromMetadata(incoming)
		}

		if names := incoming.Get(metadataClientNameKey); len(names) > 0 {
			ctx = correlation.ContextWithClientName(ctx, names[0])
		}
	}

	// Fall back to a locally produced ID when nothing usable was propagated.
	if id == "" {
		id = correlation.SafeRandomID()
	}

	return correlation.ContextWithCorrelation(ctx, id), id
}
// CorrelationIDFromMetadata looks up the correlation ID stored under
// metadataCorrelatorKey in the given request/response metadata.
// If several values are present the first one is returned; if there are none
// the result is an empty string.
func CorrelationIDFromMetadata(md metadata.MD) string {
	if ids := md.Get(metadataCorrelatorKey); len(ids) != 0 {
		return ids[0]
	}
	return ""
}
// UnaryServerCorrelationInterceptor returns a unary server interceptor that
// establishes a correlation ID for every call and exposes it to the handler
// through the context.
//
// The interceptor's behaviour is controlled by opts:
//   - when incoming propagation is enabled, a correlation ID found in the
//     incoming metadata is reused; otherwise (or when none is present) a new
//     one is generated by extractFromContext;
//   - when reverse propagation is enabled, the correlation ID is also sent
//     back to the caller as a response header under metadataCorrelatorKey.
//
// If the response header cannot be set, the handler is not invoked and the
// error is returned to the caller.
func UnaryServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.UnaryServerInterceptor {
	config := applyServerCorrelationInterceptorOptions(opts)

	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		ctx, correlationID := extractFromContext(ctx, config.propagateIncomingCorrelationID)

		if config.reversePropagateCorrelationID {
			// grpc.SetHeader resolves the server transport stream from ctx and
			// reports an error when none is attached. Calling SetHeader on the
			// result of grpc.ServerTransportStreamFromContext directly would
			// panic with a nil dereference in that situation.
			if err := grpc.SetHeader(ctx, metadata.Pairs(metadataCorrelatorKey, correlationID)); err != nil {
				return nil, err
			}
		}

		return handler(ctx, req)
	}
}
// StreamServerCorrelationInterceptor returns a streaming server interceptor
// that establishes a correlation ID for each stream.
//
// The server stream is wrapped so that its context is replaced by the one
// produced by extractFromContext. When reverse propagation is enabled in
// opts, the correlation ID is additionally set as a response header under
// metadataCorrelatorKey; a failure to do so aborts the call before the
// handler runs.
func StreamServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.StreamServerInterceptor {
	cfg := applyServerCorrelationInterceptorOptions(opts)

	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		stream := grpc_middleware.WrapServerStream(ss)

		enriched, id := extractFromContext(ss.Context(), cfg.propagateIncomingCorrelationID)
		stream.WrappedContext = enriched

		if cfg.reversePropagateCorrelationID {
			header := metadata.Pairs(metadataCorrelatorKey, id)
			if err := stream.SetHeader(header); err != nil {
				return err
			}
		}

		return handler(srv, stream)
	}
}