gitlab.com/gitlab-org/labkit@v1.21.0/correlation/grpc/server_interceptors_test.go (about) 1 package grpccorrelation 2 3 import ( 4 "context" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 "github.com/stretchr/testify/require" 9 "gitlab.com/gitlab-org/labkit/correlation" 10 "google.golang.org/grpc" 11 "google.golang.org/grpc/metadata" 12 ) 13 14 var ( 15 _ grpc.ServerTransportStream = (*mockServerTransportStream)(nil) 16 _ grpc.ServerStream = (*mockServerStream)(nil) 17 ) 18 19 type tcType struct { 20 name string 21 md metadata.MD 22 withoutPropagation bool 23 24 expectRandom bool 25 expectedClientName string 26 } 27 28 func TestServerCorrelationInterceptors(t *testing.T) { 29 tests := []tcType{ 30 { 31 name: "default", 32 md: metadata.Pairs( 33 metadataCorrelatorKey, 34 correlationID, 35 metadataClientNameKey, 36 clientName, 37 ), 38 expectedClientName: clientName, 39 }, 40 { 41 name: "id present but not trusted", 42 md: metadata.Pairs( 43 metadataCorrelatorKey, 44 correlationID, 45 ), 46 withoutPropagation: true, 47 expectRandom: true, 48 }, 49 { 50 name: "id present, trusted but empty", 51 md: metadata.Pairs( 52 metadataCorrelatorKey, 53 "", 54 ), 55 withoutPropagation: true, 56 expectRandom: true, 57 }, 58 { 59 name: "id absent and not trusted", 60 md: metadata.Pairs(), 61 withoutPropagation: true, 62 expectRandom: true, 63 }, 64 { 65 name: "id absent and trusted", 66 md: metadata.Pairs(), 67 expectRandom: true, 68 }, 69 { 70 name: "no metadata", 71 md: nil, 72 expectRandom: true, 73 }, 74 } 75 76 t.Run("unary", func(t *testing.T) { 77 for _, tc := range tests { 78 t.Run(tc.name, testUnaryServerCorrelationInterceptor(tc, false)) 79 t.Run(tc.name+" (reverse)", testUnaryServerCorrelationInterceptor(tc, true)) 80 } 81 }) 82 t.Run("streaming", func(t *testing.T) { 83 for _, tc := range tests { 84 t.Run(tc.name, testStreamingServerCorrelationInterceptor(tc, false)) 85 t.Run(tc.name+" (reverse)", testStreamingServerCorrelationInterceptor(tc, true)) 86 } 87 }) 88 } 89 90 func testUnaryServerCorrelationInterceptor(tc tcType, reverseCorrelationID bool) func(*testing.T) { 91 return func(t *testing.T) { 92 t.Helper() 93 94 sts := &mockServerTransportStream{} 95 ctx := grpc.NewContextWithServerTransportStream(context.Background(), sts) 96 if tc.md != nil { 97 ctx = metadata.NewIncomingContext(ctx, tc.md) 98 } 99 interceptor := UnaryServerCorrelationInterceptor(constructServerOpts(tc, reverseCorrelationID)...) 100 _, err := interceptor( 101 ctx, 102 nil, 103 nil, 104 func(ctx context.Context, req interface{}) (interface{}, error) { 105 testServerCtx(ctx, t, tc, reverseCorrelationID, sts.header) 106 return nil, nil 107 }, 108 ) 109 require.NoError(t, err) 110 } 111 } 112 113 func testStreamingServerCorrelationInterceptor(tc tcType, reverseCorrelationID bool) func(*testing.T) { 114 return func(t *testing.T) { 115 t.Helper() 116 117 ctx := context.Background() 118 if tc.md != nil { 119 ctx = metadata.NewIncomingContext(ctx, tc.md) 120 } 121 ss := &mockServerStream{ 122 ctx: ctx, 123 } 124 interceptor := StreamServerCorrelationInterceptor(constructServerOpts(tc, reverseCorrelationID)...) 125 err := interceptor( 126 nil, 127 ss, 128 nil, 129 func(srv interface{}, stream grpc.ServerStream) error { 130 testServerCtx(stream.Context(), t, tc, reverseCorrelationID, ss.header) 131 return nil 132 }, 133 ) 134 require.NoError(t, err) 135 } 136 } 137 138 func constructServerOpts(tc tcType, reverseCorrelationID bool) []ServerCorrelationInterceptorOption { 139 var opts []ServerCorrelationInterceptorOption 140 if tc.withoutPropagation { 141 opts = append(opts, WithoutPropagation()) 142 } 143 if reverseCorrelationID { 144 opts = append(opts, WithReversePropagation()) 145 } 146 return opts 147 } 148 149 func testServerCtx(ctx context.Context, t *testing.T, tc tcType, reverseCorrelationID bool, header metadata.MD) { 150 t.Helper() 151 152 actualID := correlation.ExtractFromContext(ctx) 153 if tc.expectRandom { 154 assert.NotEqual(t, correlationID, actualID) 155 assert.NotEmpty(t, actualID) 156 } else { 157 assert.Equal(t, correlationID, actualID) 158 } 159 vals := header.Get(metadataCorrelatorKey) 160 if reverseCorrelationID { 161 assert.Equal(t, []string{actualID}, vals) 162 } else { 163 assert.Empty(t, vals) 164 } 165 assert.Equal(t, tc.expectedClientName, correlation.ExtractClientNameFromContext(ctx)) 166 } 167 168 type mockServerTransportStream struct { 169 header metadata.MD 170 } 171 172 func (s *mockServerTransportStream) Method() string { 173 panic("implement me") 174 } 175 176 func (s *mockServerTransportStream) SetHeader(md metadata.MD) error { 177 s.header = metadata.Join(s.header, md) 178 return nil 179 } 180 181 func (s *mockServerTransportStream) SendHeader(md metadata.MD) error { 182 panic("implement me") 183 } 184 185 func (s *mockServerTransportStream) SetTrailer(md metadata.MD) error { 186 panic("implement me") 187 } 188 189 type mockServerStream struct { 190 ctx context.Context 191 header metadata.MD 192 } 193 194 func (s *mockServerStream) SetHeader(md metadata.MD) error { 195 s.header = metadata.Join(s.header, md) 196 return nil 197 } 198 199 func (s *mockServerStream) SendHeader(md metadata.MD) error { 200 panic("implement me") 201 } 202 203 func (s *mockServerStream) SetTrailer(md metadata.MD) { 204 panic("implement me") 205 } 206 207 func (s *mockServerStream) Context() context.Context { 208 return s.ctx 209 } 210 211 func (s *mockServerStream) SendMsg(m interface{}) error { 212 panic("implement me") 213 } 214 215 func (s *mockServerStream) RecvMsg(m interface{}) error { 216 panic("implement me") 217 }