gitlab.com/gitlab-org/labkit@v1.21.0/correlation/grpc/server_interceptors.go (about) 1 package grpccorrelation 2 3 import ( 4 "context" 5 6 grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" 7 "gitlab.com/gitlab-org/labkit/correlation" 8 "google.golang.org/grpc" 9 "google.golang.org/grpc/metadata" 10 ) 11 12 func extractFromContext(ctx context.Context, propagateIncomingCorrelationID bool) (context.Context, string) { 13 var correlationID string 14 md, ok := metadata.FromIncomingContext(ctx) 15 if ok { 16 if propagateIncomingCorrelationID { 17 // Extract correlation_id 18 correlationID = CorrelationIDFromMetadata(md) 19 } 20 21 // Extract client name 22 clientNames := md.Get(metadataClientNameKey) 23 if len(clientNames) > 0 { 24 ctx = correlation.ContextWithClientName(ctx, clientNames[0]) 25 } 26 } 27 if correlationID == "" { 28 correlationID = correlation.SafeRandomID() 29 } 30 ctx = correlation.ContextWithCorrelation(ctx, correlationID) 31 return ctx, correlationID 32 } 33 34 // CorrelationIDFromMetadata can be used to extract correlation ID from request/response metadata. 35 // Returns an empty string if correlation ID is not found. 36 func CorrelationIDFromMetadata(md metadata.MD) string { 37 values := md.Get(metadataCorrelatorKey) 38 if len(values) > 0 { 39 return values[0] 40 } 41 return "" 42 } 43 44 // UnaryServerCorrelationInterceptor propagates Correlation-IDs from incoming upstream services. 45 func UnaryServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.UnaryServerInterceptor { 46 config := applyServerCorrelationInterceptorOptions(opts) 47 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 48 ctx, correlationID := extractFromContext(ctx, config.propagateIncomingCorrelationID) 49 if config.reversePropagateCorrelationID { 50 sts := grpc.ServerTransportStreamFromContext(ctx) 51 err := sts.SetHeader(metadata.Pairs(metadataCorrelatorKey, correlationID)) 52 if err != nil { 53 return nil, err 54 } 55 } 56 return handler(ctx, req) 57 } 58 } 59 60 // StreamServerCorrelationInterceptor propagates Correlation-IDs from incoming upstream services. 61 func StreamServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.StreamServerInterceptor { 62 config := applyServerCorrelationInterceptorOptions(opts) 63 return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 64 var correlationID string 65 wrapped := grpc_middleware.WrapServerStream(ss) 66 wrapped.WrappedContext, correlationID = extractFromContext(ss.Context(), config.propagateIncomingCorrelationID) 67 if config.reversePropagateCorrelationID { 68 err := wrapped.SetHeader(metadata.Pairs(metadataCorrelatorKey, correlationID)) 69 if err != nil { 70 return err 71 } 72 } 73 74 return handler(srv, wrapped) 75 } 76 }