trpc.group/trpc-go/trpc-go@v1.0.3/rpcz/context_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package rpcz
    15  
    16  import (
    17  	"context"
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func TestSpanFromContext(t *testing.T) {
    24  	t.Run("empty context", func(t *testing.T) {
    25  		require.Panicsf(t, func() { SpanFromContext(nil) }, "should panic because of nil pointer dereference")
    26  	})
    27  	t.Run("background context", func(t *testing.T) {
    28  		require.Equal(t, GlobalRPCZ, SpanFromContext(context.Background()))
    29  	})
    30  	rootSpan := newSpan("root span", SpanID(1), nil)
    31  	t.Run("context has root span", func(t *testing.T) {
    32  		ctx := ContextWithSpan(context.Background(), rootSpan)
    33  		require.Equal(t, rootSpan, SpanFromContext(ctx))
    34  	})
    35  	t.Run("context has child span", func(t *testing.T) {
    36  		childSpan := newSpan("child span", SpanID(2), nil)
    37  		ctx := ContextWithSpan(context.Background(), childSpan)
    38  		require.Equal(t, childSpan, SpanFromContext(ctx))
    39  	})
    40  }
    41  
    42  func TestCurrentSpanKey(t *testing.T) {
    43  	type notSpanKey struct{}
    44  	t.Run("key is not spanKey", func(t *testing.T) {
    45  		ctx := context.Background()
    46  		ctx = context.WithValue(ctx, notSpanKey{}, &span{})
    47  		s := SpanFromContext(ctx)
    48  		require.IsType(t, &RPCZ{}, s)
    49  	})
    50  	t.Run("key is spanKey", func(t *testing.T) {
    51  		ctx := context.Background()
    52  		ctx = context.WithValue(ctx, spanKey{}, &span{})
    53  		s := SpanFromContext(ctx)
    54  		require.IsType(t, &span{}, s)
    55  	})
    56  }
    57  
    58  func TestNewSpanContext(t *testing.T) {
    59  	t.Run("new noop span", func(t *testing.T) {
    60  		ctx := context.Background()
    61  		span, _, newCtx1 := NewSpanContext(ctx, "server")
    62  		require.Equal(t, ctx, newCtx1)
    63  		require.Equal(t, GlobalRPCZ, span)
    64  
    65  		NewSpanContext(ctx, "filter")
    66  		span, _, newCtx2 := NewSpanContext(newCtx1, "server")
    67  		require.Equal(t, newCtx1, newCtx2)
    68  		require.Equal(t, GlobalRPCZ, span)
    69  	})
    70  
    71  	t.Run("new *span", func(t *testing.T) {
    72  		ctx := ContextWithSpan(context.Background(), &span{})
    73  		s, _, newCtx1 := NewSpanContext(ctx, "server")
    74  		require.NotEqual(t, ctx, newCtx1)
    75  		require.NotEqual(t, noopSpan{}, s)
    76  
    77  		NewSpanContext(ctx, "filter")
    78  		s, _, newCtx2 := NewSpanContext(newCtx1, "server")
    79  		require.NotEqual(t, newCtx1, newCtx2)
    80  		require.NotEqual(t, noopSpan{}, s)
    81  	})
    82  }