github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/go-grpc-middleware/tracing/opentracing/interceptors_test.go (about) 1 // Copyright 2017 Michal Witkowski. All Rights Reserved. 2 // See LICENSE for licensing terms. 3 4 package grpc_opentracing_test 5 6 import ( 7 "errors" 8 "strconv" 9 "strings" 10 "testing" 11 12 "fmt" 13 14 http "github.com/hxx258456/ccgo/gmhttp" 15 16 "io" 17 18 grpc_middleware "github.com/hxx258456/ccgo/go-grpc-middleware" 19 grpc_ctxtags "github.com/hxx258456/ccgo/go-grpc-middleware/tags" 20 grpc_testing "github.com/hxx258456/ccgo/go-grpc-middleware/testing" 21 pb_testproto "github.com/hxx258456/ccgo/go-grpc-middleware/testing/testproto" 22 grpc_opentracing "github.com/hxx258456/ccgo/go-grpc-middleware/tracing/opentracing" 23 "github.com/hxx258456/ccgo/grpc" 24 "github.com/hxx258456/ccgo/grpc/codes" 25 "github.com/hxx258456/ccgo/net/context" 26 "github.com/opentracing/opentracing-go" 27 "github.com/opentracing/opentracing-go/mocktracer" 28 "github.com/stretchr/testify/assert" 29 "github.com/stretchr/testify/require" 30 "github.com/stretchr/testify/suite" 31 ) 32 33 var ( 34 goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} 35 fakeInboundTraceId = 1337 36 fakeInboundSpanId = 999 37 ) 38 39 type tracingAssertService struct { 40 pb_testproto.TestServiceServer 41 T *testing.T 42 } 43 44 func (s *tracingAssertService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { 45 assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail") 46 tags := grpc_ctxtags.Extract(ctx) 47 assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid") 48 assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid") 49 assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled") 50 assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true") 51 return s.TestServiceServer.Ping(ctx, ping) 52 } 53 54 func (s *tracingAssertService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) { 55 assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail") 56 return s.TestServiceServer.PingError(ctx, ping) 57 } 58 59 func (s *tracingAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { 60 assert.NotNil(s.T, opentracing.SpanFromContext(stream.Context()), "handlers must have the spancontext in their context, otherwise propagation will fail") 61 tags := grpc_ctxtags.Extract(stream.Context()) 62 assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid") 63 assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid") 64 assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled") 65 assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true") 66 return s.TestServiceServer.PingList(ping, stream) 67 } 68 69 func (s *tracingAssertService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) { 70 assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail") 71 tags := grpc_ctxtags.Extract(ctx) 72 assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid") 73 assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid") 74 assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled") 75 assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "false", "sampled must be set to false") 76 return s.TestServiceServer.PingEmpty(ctx, empty) 77 } 78 79 func TestTaggingSuite(t *testing.T) { 80 mockTracer := mocktracer.New() 81 opts := []grpc_opentracing.Option{ 82 grpc_opentracing.WithTracer(mockTracer), 83 } 84 s := &OpentracingSuite{ 85 mockTracer: mockTracer, 86 InterceptorTestSuite: makeInterceptorTestSuite(t, opts), 87 } 88 suite.Run(t, s) 89 } 90 91 func TestTaggingSuiteJaeger(t *testing.T) { 92 mockTracer := mocktracer.New() 93 mockTracer.RegisterInjector(opentracing.HTTPHeaders, jaegerFormatInjector{}) 94 mockTracer.RegisterExtractor(opentracing.HTTPHeaders, jaegerFormatExtractor{}) 95 opts := []grpc_opentracing.Option{ 96 grpc_opentracing.WithTracer(mockTracer), 97 } 98 s := &OpentracingSuite{ 99 mockTracer: mockTracer, 100 InterceptorTestSuite: makeInterceptorTestSuite(t, opts), 101 } 102 suite.Run(t, s) 103 } 104 105 func makeInterceptorTestSuite(t *testing.T, opts []grpc_opentracing.Option) *grpc_testing.InterceptorTestSuite { 106 107 return &grpc_testing.InterceptorTestSuite{ 108 TestService: &tracingAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}, T: t}, 109 ClientOpts: []grpc.DialOption{ 110 grpc.WithUnaryInterceptor(grpc_opentracing.UnaryClientInterceptor(opts...)), 111 grpc.WithStreamInterceptor(grpc_opentracing.StreamClientInterceptor(opts...)), 112 }, 113 ServerOpts: []grpc.ServerOption{ 114 grpc_middleware.WithStreamServerChain( 115 grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), 116 grpc_opentracing.StreamServerInterceptor(opts...)), 117 grpc_middleware.WithUnaryServerChain( 118 grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), 119 grpc_opentracing.UnaryServerInterceptor(opts...)), 120 }, 121 } 122 } 123 124 type OpentracingSuite struct { 125 *grpc_testing.InterceptorTestSuite 126 mockTracer *mocktracer.MockTracer 127 } 128 129 func (s *OpentracingSuite) SetupTest() { 130 s.mockTracer.Reset() 131 } 132 133 func (s *OpentracingSuite) createContextFromFakeHttpRequestParent(ctx context.Context, sampled bool) context.Context { 134 jFlag := 0 135 if sampled { 136 jFlag = 1 137 } 138 139 hdr := http.Header{} 140 hdr.Set("uber-trace-id", fmt.Sprintf("%d:%d:%d:%d", fakeInboundTraceId, fakeInboundSpanId, fakeInboundSpanId, jFlag)) 141 hdr.Set("mockpfx-ids-traceid", fmt.Sprint(fakeInboundTraceId)) 142 hdr.Set("mockpfx-ids-spanid", fmt.Sprint(fakeInboundSpanId)) 143 hdr.Set("mockpfx-ids-sampled", fmt.Sprint(sampled)) 144 145 parentSpanContext, err := s.mockTracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(hdr)) 146 require.NoError(s.T(), err, "parsing a fake HTTP request headers shouldn't fail, ever") 147 fakeSpan := s.mockTracer.StartSpan( 148 "/fake/parent/http/request", 149 // this is magical, it attaches the new span to the parent parentSpanContext, and creates an unparented one if empty. 150 opentracing.ChildOf(parentSpanContext), 151 ) 152 fakeSpan.Finish() 153 return opentracing.ContextWithSpan(ctx, fakeSpan) 154 } 155 156 func (s *OpentracingSuite) assertTracesCreated(methodName string) (clientSpan *mocktracer.MockSpan, serverSpan *mocktracer.MockSpan) { 157 spans := s.mockTracer.FinishedSpans() 158 for _, span := range spans { 159 s.T().Logf("span: %v, tags: %v", span, span.Tags()) 160 } 161 require.Len(s.T(), spans, 3, "should record 3 spans: one fake inbound, one client, one server") 162 traceIdAssert := fmt.Sprintf("traceId=%d", fakeInboundTraceId) 163 for _, span := range spans { 164 assert.Contains(s.T(), span.String(), traceIdAssert, "not part of the fake parent trace: %v", span) 165 if span.OperationName == methodName { 166 kind := fmt.Sprintf("%v", span.Tag("span.kind")) 167 if kind == "client" { 168 clientSpan = span 169 } else if kind == "server" { 170 serverSpan = span 171 } 172 assert.EqualValues(s.T(), span.Tag("component"), "gRPC", "span must be tagged with gRPC component") 173 } 174 } 175 require.NotNil(s.T(), clientSpan, "client span must be there") 176 require.NotNil(s.T(), serverSpan, "server span must be there") 177 assert.EqualValues(s.T(), serverSpan.Tag("grpc.request.value"), "something", "grpc_ctxtags must be propagated, in this case ones from request fields") 178 return clientSpan, serverSpan 179 } 180 181 func (s *OpentracingSuite) TestPing_PropagatesTraces() { 182 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true) 183 _, err := s.Client.Ping(ctx, goodPing) 184 require.NoError(s.T(), err, "there must be not be an on a successful call") 185 s.assertTracesCreated("/mwitkow.testproto.TestService/Ping") 186 } 187 188 func (s *OpentracingSuite) TestPing_ClientContextTags() { 189 const name = "opentracing.custom" 190 ctx := grpc_opentracing.ClientAddContextTags( 191 s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true), 192 opentracing.Tags{name: ""}, 193 ) 194 195 _, err := s.Client.Ping(ctx, goodPing) 196 require.NoError(s.T(), err, "there must be not be an on a successful call") 197 198 for _, span := range s.mockTracer.FinishedSpans() { 199 if span.OperationName == "/mwitkow.testproto.TestService/Ping" { 200 kind := fmt.Sprintf("%v", span.Tag("span.kind")) 201 if kind == "client" { 202 assert.Contains(s.T(), span.Tags(), name, "custom opentracing.Tags must be included in context") 203 } 204 } 205 } 206 } 207 208 func (s *OpentracingSuite) TestPingList_PropagatesTraces() { 209 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true) 210 stream, err := s.Client.PingList(ctx, goodPing) 211 require.NoError(s.T(), err, "should not fail on establishing the stream") 212 for { 213 _, err := stream.Recv() 214 if err == io.EOF { 215 break 216 } 217 require.NoError(s.T(), err, "reading stream should not fail") 218 } 219 s.assertTracesCreated("/mwitkow.testproto.TestService/PingList") 220 } 221 222 func (s *OpentracingSuite) TestPingError_PropagatesTraces() { 223 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true) 224 erroringPing := &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(codes.OutOfRange)} 225 _, err := s.Client.PingError(ctx, erroringPing) 226 require.Error(s.T(), err, "there must be an error returned here") 227 clientSpan, serverSpan := s.assertTracesCreated("/mwitkow.testproto.TestService/PingError") 228 assert.Equal(s.T(), true, clientSpan.Tag("error"), "client span needs to be marked as an error") 229 assert.Equal(s.T(), true, serverSpan.Tag("error"), "server span needs to be marked as an error") 230 } 231 232 func (s *OpentracingSuite) TestPingEmpty_NotSampleTraces() { 233 ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), false) 234 _, err := s.Client.PingEmpty(ctx, &pb_testproto.Empty{}) 235 require.NoError(s.T(), err, "there must be not be an on a successful call") 236 } 237 238 type jaegerFormatInjector struct{} 239 240 func (jaegerFormatInjector) Inject(ctx mocktracer.MockSpanContext, carrier interface{}) error { 241 w := carrier.(opentracing.TextMapWriter) 242 flags := 0 243 if ctx.Sampled { 244 flags = 1 245 } 246 w.Set("uber-trace-id", fmt.Sprintf("%d:%d::%d", ctx.TraceID, ctx.SpanID, flags)) 247 248 return nil 249 } 250 251 type jaegerFormatExtractor struct{} 252 253 func (jaegerFormatExtractor) Extract(carrier interface{}) (mocktracer.MockSpanContext, error) { 254 rval := mocktracer.MockSpanContext{Sampled: true} 255 reader, ok := carrier.(opentracing.TextMapReader) 256 if !ok { 257 return rval, opentracing.ErrInvalidCarrier 258 } 259 err := reader.ForeachKey(func(key, val string) error { 260 lowerKey := strings.ToLower(key) 261 switch { 262 case lowerKey == "uber-trace-id": 263 parts := strings.Split(val, ":") 264 if len(parts) != 4 { 265 266 return errors.New("invalid trace id format") 267 } 268 traceId, err := strconv.Atoi(parts[0]) 269 if err != nil { 270 return err 271 } 272 rval.TraceID = traceId 273 spanId, err := strconv.Atoi(parts[1]) 274 if err != nil { 275 return err 276 } 277 rval.SpanID = spanId 278 flags, err := strconv.Atoi(parts[3]) 279 if err != nil { 280 return err 281 } 282 rval.Sampled = flags%2 == 1 283 } 284 return nil 285 }) 286 if rval.TraceID == 0 || rval.SpanID == 0 { 287 return rval, opentracing.ErrSpanContextNotFound 288 } 289 if err != nil { 290 return rval, err 291 } 292 return rval, nil 293 }