gitlab.com/gitlab-org/labkit@v1.21.0/correlation/grpc/server_interceptors.go (about)

     1  package grpccorrelation
     2  
     3  import (
     4  	"context"
     5  
     6  	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
     7  	"gitlab.com/gitlab-org/labkit/correlation"
     8  	"google.golang.org/grpc"
     9  	"google.golang.org/grpc/metadata"
    10  )
    11  
    12  func extractFromContext(ctx context.Context, propagateIncomingCorrelationID bool) (context.Context, string) {
    13  	var correlationID string
    14  	md, ok := metadata.FromIncomingContext(ctx)
    15  	if ok {
    16  		if propagateIncomingCorrelationID {
    17  			// Extract correlation_id
    18  			correlationID = CorrelationIDFromMetadata(md)
    19  		}
    20  
    21  		// Extract client name
    22  		clientNames := md.Get(metadataClientNameKey)
    23  		if len(clientNames) > 0 {
    24  			ctx = correlation.ContextWithClientName(ctx, clientNames[0])
    25  		}
    26  	}
    27  	if correlationID == "" {
    28  		correlationID = correlation.SafeRandomID()
    29  	}
    30  	ctx = correlation.ContextWithCorrelation(ctx, correlationID)
    31  	return ctx, correlationID
    32  }
    33  
    34  // CorrelationIDFromMetadata can be used to extract correlation ID from request/response metadata.
    35  // Returns an empty string if correlation ID is not found.
    36  func CorrelationIDFromMetadata(md metadata.MD) string {
    37  	values := md.Get(metadataCorrelatorKey)
    38  	if len(values) > 0 {
    39  		return values[0]
    40  	}
    41  	return ""
    42  }
    43  
    44  // UnaryServerCorrelationInterceptor propagates Correlation-IDs from incoming upstream services.
    45  func UnaryServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.UnaryServerInterceptor {
    46  	config := applyServerCorrelationInterceptorOptions(opts)
    47  	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    48  		ctx, correlationID := extractFromContext(ctx, config.propagateIncomingCorrelationID)
    49  		if config.reversePropagateCorrelationID {
    50  			sts := grpc.ServerTransportStreamFromContext(ctx)
    51  			err := sts.SetHeader(metadata.Pairs(metadataCorrelatorKey, correlationID))
    52  			if err != nil {
    53  				return nil, err
    54  			}
    55  		}
    56  		return handler(ctx, req)
    57  	}
    58  }
    59  
    60  // StreamServerCorrelationInterceptor propagates Correlation-IDs from incoming upstream services.
    61  func StreamServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.StreamServerInterceptor {
    62  	config := applyServerCorrelationInterceptorOptions(opts)
    63  	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    64  		var correlationID string
    65  		wrapped := grpc_middleware.WrapServerStream(ss)
    66  		wrapped.WrappedContext, correlationID = extractFromContext(ss.Context(), config.propagateIncomingCorrelationID)
    67  		if config.reversePropagateCorrelationID {
    68  			err := wrapped.SetHeader(metadata.Pairs(metadataCorrelatorKey, correlationID))
    69  			if err != nil {
    70  				return err
    71  			}
    72  		}
    73  
    74  		return handler(srv, wrapped)
    75  	}
    76  }