gitlab.com/gitlab-org/labkit@v1.21.0/correlation/grpc/server_interceptors_test.go

     1  package grpccorrelation
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"github.com/stretchr/testify/require"
     9  	"gitlab.com/gitlab-org/labkit/correlation"
    10  	"google.golang.org/grpc"
    11  	"google.golang.org/grpc/metadata"
    12  )
    13  
    14  var (
    15  	_ grpc.ServerTransportStream = (*mockServerTransportStream)(nil)
    16  	_ grpc.ServerStream          = (*mockServerStream)(nil)
    17  )
    18  
    19  type tcType struct {
    20  	name               string
    21  	md                 metadata.MD
    22  	withoutPropagation bool
    23  
    24  	expectRandom       bool
    25  	expectedClientName string
    26  }
    27  
    28  func TestServerCorrelationInterceptors(t *testing.T) {
    29  	tests := []tcType{
    30  		{
    31  			name: "default",
    32  			md: metadata.Pairs(
    33  				metadataCorrelatorKey,
    34  				correlationID,
    35  				metadataClientNameKey,
    36  				clientName,
    37  			),
    38  			expectedClientName: clientName,
    39  		},
    40  		{
    41  			name: "id present but not trusted",
    42  			md: metadata.Pairs(
    43  				metadataCorrelatorKey,
    44  				correlationID,
    45  			),
    46  			withoutPropagation: true,
    47  			expectRandom:       true,
    48  		},
    49  		{
    50  			name: "id present, trusted but empty",
    51  			md: metadata.Pairs(
    52  				metadataCorrelatorKey,
    53  				"",
    54  			),
    55  			withoutPropagation: true,
    56  			expectRandom:       true,
    57  		},
    58  		{
    59  			name:               "id absent and not trusted",
    60  			md:                 metadata.Pairs(),
    61  			withoutPropagation: true,
    62  			expectRandom:       true,
    63  		},
    64  		{
    65  			name:         "id absent and trusted",
    66  			md:           metadata.Pairs(),
    67  			expectRandom: true,
    68  		},
    69  		{
    70  			name:         "no metadata",
    71  			md:           nil,
    72  			expectRandom: true,
    73  		},
    74  	}
    75  
    76  	t.Run("unary", func(t *testing.T) {
    77  		for _, tc := range tests {
    78  			t.Run(tc.name, testUnaryServerCorrelationInterceptor(tc, false))
    79  			t.Run(tc.name+" (reverse)", testUnaryServerCorrelationInterceptor(tc, true))
    80  		}
    81  	})
    82  	t.Run("streaming", func(t *testing.T) {
    83  		for _, tc := range tests {
    84  			t.Run(tc.name, testStreamingServerCorrelationInterceptor(tc, false))
    85  			t.Run(tc.name+" (reverse)", testStreamingServerCorrelationInterceptor(tc, true))
    86  		}
    87  	})
    88  }
    89  
    90  func testUnaryServerCorrelationInterceptor(tc tcType, reverseCorrelationID bool) func(*testing.T) {
    91  	return func(t *testing.T) {
    92  		t.Helper()
    93  
    94  		sts := &mockServerTransportStream{}
    95  		ctx := grpc.NewContextWithServerTransportStream(context.Background(), sts)
    96  		if tc.md != nil {
    97  			ctx = metadata.NewIncomingContext(ctx, tc.md)
    98  		}
    99  		interceptor := UnaryServerCorrelationInterceptor(constructServerOpts(tc, reverseCorrelationID)...)
   100  		_, err := interceptor(
   101  			ctx,
   102  			nil,
   103  			nil,
   104  			func(ctx context.Context, req interface{}) (interface{}, error) {
   105  				testServerCtx(ctx, t, tc, reverseCorrelationID, sts.header)
   106  				return nil, nil
   107  			},
   108  		)
   109  		require.NoError(t, err)
   110  	}
   111  }
   112  
   113  func testStreamingServerCorrelationInterceptor(tc tcType, reverseCorrelationID bool) func(*testing.T) {
   114  	return func(t *testing.T) {
   115  		t.Helper()
   116  
   117  		ctx := context.Background()
   118  		if tc.md != nil {
   119  			ctx = metadata.NewIncomingContext(ctx, tc.md)
   120  		}
   121  		ss := &mockServerStream{
   122  			ctx: ctx,
   123  		}
   124  		interceptor := StreamServerCorrelationInterceptor(constructServerOpts(tc, reverseCorrelationID)...)
   125  		err := interceptor(
   126  			nil,
   127  			ss,
   128  			nil,
   129  			func(srv interface{}, stream grpc.ServerStream) error {
   130  				testServerCtx(stream.Context(), t, tc, reverseCorrelationID, ss.header)
   131  				return nil
   132  			},
   133  		)
   134  		require.NoError(t, err)
   135  	}
   136  }
   137  
   138  func constructServerOpts(tc tcType, reverseCorrelationID bool) []ServerCorrelationInterceptorOption {
   139  	var opts []ServerCorrelationInterceptorOption
   140  	if tc.withoutPropagation {
   141  		opts = append(opts, WithoutPropagation())
   142  	}
   143  	if reverseCorrelationID {
   144  		opts = append(opts, WithReversePropagation())
   145  	}
   146  	return opts
   147  }
   148  
   149  func testServerCtx(ctx context.Context, t *testing.T, tc tcType, reverseCorrelationID bool, header metadata.MD) {
   150  	t.Helper()
   151  
   152  	actualID := correlation.ExtractFromContext(ctx)
   153  	if tc.expectRandom {
   154  		assert.NotEqual(t, correlationID, actualID)
   155  		assert.NotEmpty(t, actualID)
   156  	} else {
   157  		assert.Equal(t, correlationID, actualID)
   158  	}
   159  	vals := header.Get(metadataCorrelatorKey)
   160  	if reverseCorrelationID {
   161  		assert.Equal(t, []string{actualID}, vals)
   162  	} else {
   163  		assert.Empty(t, vals)
   164  	}
   165  	assert.Equal(t, tc.expectedClientName, correlation.ExtractClientNameFromContext(ctx))
   166  }
   167  
   168  type mockServerTransportStream struct {
   169  	header metadata.MD
   170  }
   171  
   172  func (s *mockServerTransportStream) Method() string {
   173  	panic("implement me")
   174  }
   175  
   176  func (s *mockServerTransportStream) SetHeader(md metadata.MD) error {
   177  	s.header = metadata.Join(s.header, md)
   178  	return nil
   179  }
   180  
   181  func (s *mockServerTransportStream) SendHeader(md metadata.MD) error {
   182  	panic("implement me")
   183  }
   184  
   185  func (s *mockServerTransportStream) SetTrailer(md metadata.MD) error {
   186  	panic("implement me")
   187  }
   188  
   189  type mockServerStream struct {
   190  	ctx    context.Context
   191  	header metadata.MD
   192  }
   193  
   194  func (s *mockServerStream) SetHeader(md metadata.MD) error {
   195  	s.header = metadata.Join(s.header, md)
   196  	return nil
   197  }
   198  
   199  func (s *mockServerStream) SendHeader(md metadata.MD) error {
   200  	panic("implement me")
   201  }
   202  
   203  func (s *mockServerStream) SetTrailer(md metadata.MD) {
   204  	panic("implement me")
   205  }
   206  
   207  func (s *mockServerStream) Context() context.Context {
   208  	return s.ctx
   209  }
   210  
   211  func (s *mockServerStream) SendMsg(m interface{}) error {
   212  	panic("implement me")
   213  }
   214  
   215  func (s *mockServerStream) RecvMsg(m interface{}) error {
   216  	panic("implement me")
   217  }