github.com/m3db/m3@v1.5.0/src/x/opentracing/context_test.go (about)

     1  // Copyright (c) 2019 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package opentracing
    22  
    23  import (
    24  	"context"
    25  	"testing"
    26  
    27  	"github.com/opentracing/opentracing-go"
    28  	"github.com/opentracing/opentracing-go/mocktracer"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  )
    32  
    33  func TestStartSpanFromContext(t *testing.T) {
    34  	t.Run("uses local tracer if available", func(t *testing.T) {
    35  		mtr := mocktracer.New()
    36  		defer mockoutGlobalTracer(opentracing.NoopTracer{})()
    37  
    38  		rootSp := mtr.StartSpan("root")
    39  		ctx := opentracing.ContextWithSpan(context.Background(), rootSp)
    40  
    41  		childSp, ctx := StartSpanFromContext(ctx, "child")
    42  
    43  		rootSp.Finish()
    44  		childSp.Finish()
    45  
    46  		assertHasSpans(t, mtr, []string{"root", "child"})
    47  
    48  		ctxSpan := opentracing.SpanFromContext(ctx)
    49  		require.IsType(t, (*mocktracer.MockSpan)(nil), ctxSpan)
    50  		assert.Equal(t, ctxSpan.(*mocktracer.MockSpan).OperationName, "child",
    51  			"should set span on context to child span")
    52  
    53  	})
    54  
    55  	t.Run("uses global tracer if no parent span", func(t *testing.T) {
    56  		mtr := mocktracer.New()
    57  		defer mockoutGlobalTracer(mtr)()
    58  
    59  		sp, ctx := StartSpanFromContext(context.Background(), "foo")
    60  		sp.Finish()
    61  
    62  		assertHasSpans(t, mtr, []string{"foo"})
    63  		assert.NotNil(t, opentracing.SpanFromContext(ctx))
    64  	})
    65  }
    66  
    67  func TestStartSpanFromContextOrRoot(t *testing.T) {
    68  	t.Run("uses noop tracer if nothing passed in", func(t *testing.T) {
    69  		assert.Equal(t, opentracing.NoopTracer{}.StartSpan(""),
    70  			SpanFromContextOrNoop(context.
    71  				Background()))
    72  	})
    73  
    74  	t.Run("returns span if span attached to context", func(t *testing.T) {
    75  		mt := mocktracer.New()
    76  		root := mt.StartSpan("root")
    77  		ctx := opentracing.ContextWithSpan(context.Background(), root)
    78  
    79  		assert.Equal(t, root, SpanFromContextOrNoop(ctx))
    80  	})
    81  
    82  }
    83  
    84  func mockoutGlobalTracer(mtr opentracing.Tracer) func() {
    85  	oldGGT := getGlobalTracer
    86  	getGlobalTracer = func() opentracing.Tracer {
    87  		return mtr
    88  	}
    89  
    90  	return func() {
    91  		getGlobalTracer = oldGGT
    92  	}
    93  }
    94  
    95  func assertHasSpans(t *testing.T, mtr *mocktracer.MockTracer, opNames []string) {
    96  	spans := mtr.FinishedSpans()
    97  	actualOpNames := make([]string, len(spans))
    98  
    99  	for i, sp := range spans {
   100  		actualOpNames[i] = sp.OperationName
   101  	}
   102  
   103  	assert.Equal(t, opNames, actualOpNames)
   104  }