
     1  // Copyright 2017 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     4  package grpc_opentracing_test
     6  import (
     7  	"errors"
     8  	"strconv"
     9  	"strings"
    10  	"testing"
    12  	"fmt"
    14  	http ""
    16  	"io"
    18  	grpc_middleware ""
    19  	grpc_ctxtags ""
    20  	grpc_testing ""
    21  	pb_testproto ""
    22  	grpc_opentracing ""
    23  	""
    24  	""
    25  	""
    26  	""
    27  	""
    28  	""
    29  	""
    30  	""
    31  )
    33  var (
    34  	goodPing           = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
    35  	fakeInboundTraceId = 1337
    36  	fakeInboundSpanId  = 999
    37  )
// tracingAssertService wraps a TestServiceServer and, in the handlers it
// overrides, asserts on the tracing state of the incoming context before
// delegating to the embedded implementation.
type tracingAssertService struct {
	pb_testproto.TestServiceServer
	// T is the test handle the handler-side assertions report to.
	T *testing.T
}
    44  func (s *tracingAssertService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) {
    45  	assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
    46  	tags := grpc_ctxtags.Extract(ctx)
    47  	assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
    48  	assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
    49  	assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
    50  	assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true")
    51  	return s.TestServiceServer.Ping(ctx, ping)
    52  }
    54  func (s *tracingAssertService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) {
    55  	assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
    56  	return s.TestServiceServer.PingError(ctx, ping)
    57  }
    59  func (s *tracingAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
    60  	assert.NotNil(s.T, opentracing.SpanFromContext(stream.Context()), "handlers must have the spancontext in their context, otherwise propagation will fail")
    61  	tags := grpc_ctxtags.Extract(stream.Context())
    62  	assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
    63  	assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
    64  	assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
    65  	assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true")
    66  	return s.TestServiceServer.PingList(ping, stream)
    67  }
    69  func (s *tracingAssertService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) {
    70  	assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail")
    71  	tags := grpc_ctxtags.Extract(ctx)
    72  	assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid")
    73  	assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid")
    74  	assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled")
    75  	assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "false", "sampled must be set to false")
    76  	return s.TestServiceServer.PingEmpty(ctx, empty)
    77  }
    79  func TestTaggingSuite(t *testing.T) {
    80  	mockTracer := mocktracer.New()
    81  	opts := []grpc_opentracing.Option{
    82  		grpc_opentracing.WithTracer(mockTracer),
    83  	}
    84  	s := &OpentracingSuite{
    85  		mockTracer:           mockTracer,
    86  		InterceptorTestSuite: makeInterceptorTestSuite(t, opts),
    87  	}
    88  	suite.Run(t, s)
    89  }
    91  func TestTaggingSuiteJaeger(t *testing.T) {
    92  	mockTracer := mocktracer.New()
    93  	mockTracer.RegisterInjector(opentracing.HTTPHeaders, jaegerFormatInjector{})
    94  	mockTracer.RegisterExtractor(opentracing.HTTPHeaders, jaegerFormatExtractor{})
    95  	opts := []grpc_opentracing.Option{
    96  		grpc_opentracing.WithTracer(mockTracer),
    97  	}
    98  	s := &OpentracingSuite{
    99  		mockTracer:           mockTracer,
   100  		InterceptorTestSuite: makeInterceptorTestSuite(t, opts),
   101  	}
   102  	suite.Run(t, s)
   103  }
   105  func makeInterceptorTestSuite(t *testing.T, opts []grpc_opentracing.Option) *grpc_testing.InterceptorTestSuite {
   107  	return &grpc_testing.InterceptorTestSuite{
   108  		TestService: &tracingAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}, T: t},
   109  		ClientOpts: []grpc.DialOption{
   110  			grpc.WithUnaryInterceptor(grpc_opentracing.UnaryClientInterceptor(opts...)),
   111  			grpc.WithStreamInterceptor(grpc_opentracing.StreamClientInterceptor(opts...)),
   112  		},
   113  		ServerOpts: []grpc.ServerOption{
   114  			grpc_middleware.WithStreamServerChain(
   115  				grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)),
   116  				grpc_opentracing.StreamServerInterceptor(opts...)),
   117  			grpc_middleware.WithUnaryServerChain(
   118  				grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)),
   119  				grpc_opentracing.UnaryServerInterceptor(opts...)),
   120  		},
   121  	}
   122  }
// OpentracingSuite runs the tracing interceptor tests on top of the shared
// InterceptorTestSuite.
type OpentracingSuite struct {
	*grpc_testing.InterceptorTestSuite
	// mockTracer is the tracer passed to the interceptors; the tests read the
	// finished spans back from it.
	mockTracer *mocktracer.MockTracer
}
   129  func (s *OpentracingSuite) SetupTest() {
   130  	s.mockTracer.Reset()
   131  }
   133  func (s *OpentracingSuite) createContextFromFakeHttpRequestParent(ctx context.Context, sampled bool) context.Context {
   134  	jFlag := 0
   135  	if sampled {
   136  		jFlag = 1
   137  	}
   139  	hdr := http.Header{}
   140  	hdr.Set("uber-trace-id", fmt.Sprintf("%d:%d:%d:%d", fakeInboundTraceId, fakeInboundSpanId, fakeInboundSpanId, jFlag))
   141  	hdr.Set("mockpfx-ids-traceid", fmt.Sprint(fakeInboundTraceId))
   142  	hdr.Set("mockpfx-ids-spanid", fmt.Sprint(fakeInboundSpanId))
   143  	hdr.Set("mockpfx-ids-sampled", fmt.Sprint(sampled))
   145  	parentSpanContext, err := s.mockTracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(hdr))
   146  	require.NoError(s.T(), err, "parsing a fake HTTP request headers shouldn't fail, ever")
   147  	fakeSpan := s.mockTracer.StartSpan(
   148  		"/fake/parent/http/request",
   149  		// this is magical, it attaches the new span to the parent parentSpanContext, and creates an unparented one if empty.
   150  		opentracing.ChildOf(parentSpanContext),
   151  	)
   152  	fakeSpan.Finish()
   153  	return opentracing.ContextWithSpan(ctx, fakeSpan)
   154  }
   156  func (s *OpentracingSuite) assertTracesCreated(methodName string) (clientSpan *mocktracer.MockSpan, serverSpan *mocktracer.MockSpan) {
   157  	spans := s.mockTracer.FinishedSpans()
   158  	for _, span := range spans {
   159  		s.T().Logf("span: %v, tags: %v", span, span.Tags())
   160  	}
   161  	require.Len(s.T(), spans, 3, "should record 3 spans: one fake inbound, one client, one server")
   162  	traceIdAssert := fmt.Sprintf("traceId=%d", fakeInboundTraceId)
   163  	for _, span := range spans {
   164  		assert.Contains(s.T(), span.String(), traceIdAssert, "not part of the fake parent trace: %v", span)
   165  		if span.OperationName == methodName {
   166  			kind := fmt.Sprintf("%v", span.Tag("span.kind"))
   167  			if kind == "client" {
   168  				clientSpan = span
   169  			} else if kind == "server" {
   170  				serverSpan = span
   171  			}
   172  			assert.EqualValues(s.T(), span.Tag("component"), "gRPC", "span must be tagged with gRPC component")
   173  		}
   174  	}
   175  	require.NotNil(s.T(), clientSpan, "client span must be there")
   176  	require.NotNil(s.T(), serverSpan, "server span must be there")
   177  	assert.EqualValues(s.T(), serverSpan.Tag("grpc.request.value"), "something", "grpc_ctxtags must be propagated, in this case ones from request fields")
   178  	return clientSpan, serverSpan
   179  }
   181  func (s *OpentracingSuite) TestPing_PropagatesTraces() {
   182  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true)
   183  	_, err := s.Client.Ping(ctx, goodPing)
   184  	require.NoError(s.T(), err, "there must be not be an on a successful call")
   185  	s.assertTracesCreated("/mwitkow.testproto.TestService/Ping")
   186  }
   188  func (s *OpentracingSuite) TestPing_ClientContextTags() {
   189  	const name = "opentracing.custom"
   190  	ctx := grpc_opentracing.ClientAddContextTags(
   191  		s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true),
   192  		opentracing.Tags{name: ""},
   193  	)
   195  	_, err := s.Client.Ping(ctx, goodPing)
   196  	require.NoError(s.T(), err, "there must be not be an on a successful call")
   198  	for _, span := range s.mockTracer.FinishedSpans() {
   199  		if span.OperationName == "/mwitkow.testproto.TestService/Ping" {
   200  			kind := fmt.Sprintf("%v", span.Tag("span.kind"))
   201  			if kind == "client" {
   202  				assert.Contains(s.T(), span.Tags(), name, "custom opentracing.Tags must be included in context")
   203  			}
   204  		}
   205  	}
   206  }
   208  func (s *OpentracingSuite) TestPingList_PropagatesTraces() {
   209  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true)
   210  	stream, err := s.Client.PingList(ctx, goodPing)
   211  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   212  	for {
   213  		_, err := stream.Recv()
   214  		if err == io.EOF {
   215  			break
   216  		}
   217  		require.NoError(s.T(), err, "reading stream should not fail")
   218  	}
   219  	s.assertTracesCreated("/mwitkow.testproto.TestService/PingList")
   220  }
   222  func (s *OpentracingSuite) TestPingError_PropagatesTraces() {
   223  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true)
   224  	erroringPing := &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(codes.OutOfRange)}
   225  	_, err := s.Client.PingError(ctx, erroringPing)
   226  	require.Error(s.T(), err, "there must be an error returned here")
   227  	clientSpan, serverSpan := s.assertTracesCreated("/mwitkow.testproto.TestService/PingError")
   228  	assert.Equal(s.T(), true, clientSpan.Tag("error"), "client span needs to be marked as an error")
   229  	assert.Equal(s.T(), true, serverSpan.Tag("error"), "server span needs to be marked as an error")
   230  }
   232  func (s *OpentracingSuite) TestPingEmpty_NotSampleTraces() {
   233  	ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), false)
   234  	_, err := s.Client.PingEmpty(ctx, &pb_testproto.Empty{})
   235  	require.NoError(s.T(), err, "there must be not be an on a successful call")
   236  }
   238  type jaegerFormatInjector struct{}
   240  func (jaegerFormatInjector) Inject(ctx mocktracer.MockSpanContext, carrier interface{}) error {
   241  	w := carrier.(opentracing.TextMapWriter)
   242  	flags := 0
   243  	if ctx.Sampled {
   244  		flags = 1
   245  	}
   246  	w.Set("uber-trace-id", fmt.Sprintf("%d:%d::%d", ctx.TraceID, ctx.SpanID, flags))
   248  	return nil
   249  }
   251  type jaegerFormatExtractor struct{}
   253  func (jaegerFormatExtractor) Extract(carrier interface{}) (mocktracer.MockSpanContext, error) {
   254  	rval := mocktracer.MockSpanContext{Sampled: true}
   255  	reader, ok := carrier.(opentracing.TextMapReader)
   256  	if !ok {
   257  		return rval, opentracing.ErrInvalidCarrier
   258  	}
   259  	err := reader.ForeachKey(func(key, val string) error {
   260  		lowerKey := strings.ToLower(key)
   261  		switch {
   262  		case lowerKey == "uber-trace-id":
   263  			parts := strings.Split(val, ":")
   264  			if len(parts) != 4 {
   266  				return errors.New("invalid trace id format")
   267  			}
   268  			traceId, err := strconv.Atoi(parts[0])
   269  			if err != nil {
   270  				return err
   271  			}
   272  			rval.TraceID = traceId
   273  			spanId, err := strconv.Atoi(parts[1])
   274  			if err != nil {
   275  				return err
   276  			}
   277  			rval.SpanID = spanId
   278  			flags, err := strconv.Atoi(parts[3])
   279  			if err != nil {
   280  				return err
   281  			}
   282  			rval.Sampled = flags%2 == 1
   283  		}
   284  		return nil
   285  	})
   286  	if rval.TraceID == 0 || rval.SpanID == 0 {
   287  		return rval, opentracing.ErrSpanContextNotFound
   288  	}
   289  	if err != nil {
   290  		return rval, err
   291  	}
   292  	return rval, nil
   293  }