trpc.group/trpc-go/trpc-go@v1.0.3/filter/filter_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package filter_test 15 16 import ( 17 "context" 18 "sync/atomic" 19 "testing" 20 21 "github.com/stretchr/testify/require" 22 "golang.org/x/sync/errgroup" 23 "trpc.group/trpc-go/trpc-go/filter" 24 "trpc.group/trpc-go/trpc-go/rpcz" 25 ) 26 27 func TestFilterChain(t *testing.T) { 28 ctx := context.Background() 29 req, rsp := "req", "rsp" 30 sc := filter.ServerChain{filter.NoopServerFilter} 31 _, err := sc.Filter(ctx, req, 32 func(ctx context.Context, req interface{}) (rsp interface{}, err error) { 33 return nil, nil 34 }) 35 require.Nil(t, err) 36 cc := filter.ClientChain{filter.NoopClientFilter} 37 require.Nil(t, cc.Filter(ctx, req, rsp, 38 func(ctx context.Context, req, rsp interface{}) error { 39 return nil 40 })) 41 } 42 43 func TestNamedFilter(t *testing.T) { 44 const filterName = "filterName" 45 filter.Register(filterName, filter.NoopServerFilter, filter.NoopClientFilter) 46 require.NotNil(t, filter.GetClient(filterName)) 47 require.NotNil(t, filter.GetServer(filterName)) 48 ctx := context.Background() 49 span, end := rpcz.NewRPCZ(&rpcz.Config{Fraction: 1, Capacity: 1}).NewChild("child") 50 defer end.End() 51 ctx = rpcz.ContextWithSpan(ctx, span) 52 span.SetAttribute(rpcz.TRPCAttributeFilterNames, []string{filterName}) 53 cc := filter.ClientChain{filter.NoopClientFilter} 54 require.Nil(t, cc.Filter(ctx, nil, nil, 55 func(ctx context.Context, req, rsp interface{}) error { return nil })) 56 sc := filter.ServerChain{filter.NoopServerFilter} 57 _, err := sc.Filter(ctx, nil, 58 func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) 59 require.Nil(t, err) 60 } 61 62 func TestChainConcurrentHandle(t *testing.T) { 63 const concurrentN = 4 64 var calledTimes [concurrentN]int32 65 cc := filter.ClientChain{ 66 func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) error { 67 atomic.AddInt32(&calledTimes[0], 1) 68 return f(ctx, req, rsp) 69 }, 70 func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) error { 71 atomic.AddInt32(&calledTimes[1], 1) 72 var eg errgroup.Group 73 for i := 0; i < concurrentN; i++ { 74 eg.Go(func() error { 75 return f(ctx, req, rsp) 76 }) 77 } 78 return eg.Wait() 79 }, 80 func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) (err error) { 81 atomic.AddInt32(&calledTimes[2], 1) 82 return f(ctx, req, rsp) 83 }, 84 func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) (err error) { 85 atomic.AddInt32(&calledTimes[3], 1) 86 return f(ctx, req, rsp) 87 }, 88 } 89 require.Nil(t, cc.Filter(context.Background(), nil, nil, 90 func(ctx context.Context, req, rsp interface{}) (err error) { 91 return nil 92 })) 93 require.Equal(t, int32(1), atomic.LoadInt32(&calledTimes[0])) 94 require.Equal(t, int32(1), atomic.LoadInt32(&calledTimes[1])) 95 require.Equal(t, int32(concurrentN), atomic.LoadInt32(&calledTimes[2])) 96 require.Equal(t, int32(concurrentN), atomic.LoadInt32(&calledTimes[3])) 97 }