gitlab.com/gitlab-org/labkit@v1.21.0/correlation/grpc/client_interceptors_test.go (about) 1 package grpccorrelation 2 3 import ( 4 "context" 5 "testing" 6 7 "github.com/stretchr/testify/require" 8 "gitlab.com/gitlab-org/labkit/correlation" 9 "google.golang.org/grpc" 10 "google.golang.org/grpc/metadata" 11 ) 12 13 const ( 14 correlationID = "CORRELATION_ID" 15 clientName = "CLIENT_NAME" 16 methodName = "METHOD_NAME" 17 ) 18 19 func verifyContextMetadata(ctx context.Context, require *require.Assertions, expCorrelationID, expClientName string) { 20 md, ok := metadata.FromOutgoingContext(ctx) 21 require.True(ok) 22 ids := md.Get(metadataCorrelatorKey) 23 require.Less(0, len(ids)) 24 require.Equal(expCorrelationID, ids[0]) 25 26 clientNames := md.Get(metadataClientNameKey) 27 require.Less(0, len(clientNames)) 28 require.Equal(expClientName, clientNames[0]) 29 } 30 31 func getTestUnaryInvoker(require *require.Assertions, expCorrelationID, expClientName string) grpc.UnaryInvoker { 32 return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { 33 verifyContextMetadata(ctx, require, expCorrelationID, expClientName) 34 return nil 35 } 36 } 37 38 func getTestStreamer(require *require.Assertions, expCorrelationID, expClientName string) grpc.Streamer { 39 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { 40 verifyContextMetadata(ctx, require, expCorrelationID, expClientName) 41 return nil, nil 42 } 43 } 44 45 func TestUnaryClientCorrelationInterceptor(t *testing.T) { 46 require := require.New(t) 47 48 clientInterceptor := UnaryClientCorrelationInterceptor(WithClientName(clientName)) 49 50 ctx := correlation.ContextWithCorrelation(context.Background(), correlationID) 51 err := clientInterceptor( 52 ctx, 53 methodName, 54 nil, 55 nil, 56 nil, 57 getTestUnaryInvoker(require, correlationID, clientName), 58 ) 59 require.NoError(err) 60 } 61 62 func TestStreamClientCorrelationInterceptor(t *testing.T) { 63 require := require.New(t) 64 65 clientInterceptor := StreamClientCorrelationInterceptor(WithClientName(clientName)) 66 67 ctx := correlation.ContextWithCorrelation(context.Background(), correlationID) 68 _, err := clientInterceptor( 69 ctx, 70 nil, 71 nil, 72 methodName, 73 getTestStreamer(require, correlationID, clientName), 74 ) 75 require.NoError(err) 76 }