gitlab.com/gitlab-org/labkit@v1.21.0/correlation/grpc/client_interceptors_test.go (about)

     1  package grpccorrelation
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  	"gitlab.com/gitlab-org/labkit/correlation"
     9  	"google.golang.org/grpc"
    10  	"google.golang.org/grpc/metadata"
    11  )
    12  
    13  const (
    14  	correlationID = "CORRELATION_ID"
    15  	clientName    = "CLIENT_NAME"
    16  	methodName    = "METHOD_NAME"
    17  )
    18  
    19  func verifyContextMetadata(ctx context.Context, require *require.Assertions, expCorrelationID, expClientName string) {
    20  	md, ok := metadata.FromOutgoingContext(ctx)
    21  	require.True(ok)
    22  	ids := md.Get(metadataCorrelatorKey)
    23  	require.Less(0, len(ids))
    24  	require.Equal(expCorrelationID, ids[0])
    25  
    26  	clientNames := md.Get(metadataClientNameKey)
    27  	require.Less(0, len(clientNames))
    28  	require.Equal(expClientName, clientNames[0])
    29  }
    30  
    31  func getTestUnaryInvoker(require *require.Assertions, expCorrelationID, expClientName string) grpc.UnaryInvoker {
    32  	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
    33  		verifyContextMetadata(ctx, require, expCorrelationID, expClientName)
    34  		return nil
    35  	}
    36  }
    37  
    38  func getTestStreamer(require *require.Assertions, expCorrelationID, expClientName string) grpc.Streamer {
    39  	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    40  		verifyContextMetadata(ctx, require, expCorrelationID, expClientName)
    41  		return nil, nil
    42  	}
    43  }
    44  
    45  func TestUnaryClientCorrelationInterceptor(t *testing.T) {
    46  	require := require.New(t)
    47  
    48  	clientInterceptor := UnaryClientCorrelationInterceptor(WithClientName(clientName))
    49  
    50  	ctx := correlation.ContextWithCorrelation(context.Background(), correlationID)
    51  	err := clientInterceptor(
    52  		ctx,
    53  		methodName,
    54  		nil,
    55  		nil,
    56  		nil,
    57  		getTestUnaryInvoker(require, correlationID, clientName),
    58  	)
    59  	require.NoError(err)
    60  }
    61  
    62  func TestStreamClientCorrelationInterceptor(t *testing.T) {
    63  	require := require.New(t)
    64  
    65  	clientInterceptor := StreamClientCorrelationInterceptor(WithClientName(clientName))
    66  
    67  	ctx := correlation.ContextWithCorrelation(context.Background(), correlationID)
    68  	_, err := clientInterceptor(
    69  		ctx,
    70  		nil,
    71  		nil,
    72  		methodName,
    73  		getTestStreamer(require, correlationID, clientName),
    74  	)
    75  	require.NoError(err)
    76  }