trpc.group/trpc-go/trpc-go@v1.0.3/filter/filter_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package filter_test
    15  
    16  import (
    17  	"context"
    18  	"sync/atomic"
    19  	"testing"
    20  
    21  	"github.com/stretchr/testify/require"
    22  	"golang.org/x/sync/errgroup"
    23  	"trpc.group/trpc-go/trpc-go/filter"
    24  	"trpc.group/trpc-go/trpc-go/rpcz"
    25  )
    26  
    27  func TestFilterChain(t *testing.T) {
    28  	ctx := context.Background()
    29  	req, rsp := "req", "rsp"
    30  	sc := filter.ServerChain{filter.NoopServerFilter}
    31  	_, err := sc.Filter(ctx, req,
    32  		func(ctx context.Context, req interface{}) (rsp interface{}, err error) {
    33  			return nil, nil
    34  		})
    35  	require.Nil(t, err)
    36  	cc := filter.ClientChain{filter.NoopClientFilter}
    37  	require.Nil(t, cc.Filter(ctx, req, rsp,
    38  		func(ctx context.Context, req, rsp interface{}) error {
    39  			return nil
    40  		}))
    41  }
    42  
    43  func TestNamedFilter(t *testing.T) {
    44  	const filterName = "filterName"
    45  	filter.Register(filterName, filter.NoopServerFilter, filter.NoopClientFilter)
    46  	require.NotNil(t, filter.GetClient(filterName))
    47  	require.NotNil(t, filter.GetServer(filterName))
    48  	ctx := context.Background()
    49  	span, end := rpcz.NewRPCZ(&rpcz.Config{Fraction: 1, Capacity: 1}).NewChild("child")
    50  	defer end.End()
    51  	ctx = rpcz.ContextWithSpan(ctx, span)
    52  	span.SetAttribute(rpcz.TRPCAttributeFilterNames, []string{filterName})
    53  	cc := filter.ClientChain{filter.NoopClientFilter}
    54  	require.Nil(t, cc.Filter(ctx, nil, nil,
    55  		func(ctx context.Context, req, rsp interface{}) error { return nil }))
    56  	sc := filter.ServerChain{filter.NoopServerFilter}
    57  	_, err := sc.Filter(ctx, nil,
    58  		func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil })
    59  	require.Nil(t, err)
    60  }
    61  
    62  func TestChainConcurrentHandle(t *testing.T) {
    63  	const concurrentN = 4
    64  	var calledTimes [concurrentN]int32
    65  	cc := filter.ClientChain{
    66  		func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) error {
    67  			atomic.AddInt32(&calledTimes[0], 1)
    68  			return f(ctx, req, rsp)
    69  		},
    70  		func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) error {
    71  			atomic.AddInt32(&calledTimes[1], 1)
    72  			var eg errgroup.Group
    73  			for i := 0; i < concurrentN; i++ {
    74  				eg.Go(func() error {
    75  					return f(ctx, req, rsp)
    76  				})
    77  			}
    78  			return eg.Wait()
    79  		},
    80  		func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) (err error) {
    81  			atomic.AddInt32(&calledTimes[2], 1)
    82  			return f(ctx, req, rsp)
    83  		},
    84  		func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) (err error) {
    85  			atomic.AddInt32(&calledTimes[3], 1)
    86  			return f(ctx, req, rsp)
    87  		},
    88  	}
    89  	require.Nil(t, cc.Filter(context.Background(), nil, nil,
    90  		func(ctx context.Context, req, rsp interface{}) (err error) {
    91  			return nil
    92  		}))
    93  	require.Equal(t, int32(1), atomic.LoadInt32(&calledTimes[0]))
    94  	require.Equal(t, int32(1), atomic.LoadInt32(&calledTimes[1]))
    95  	require.Equal(t, int32(concurrentN), atomic.LoadInt32(&calledTimes[2]))
    96  	require.Equal(t, int32(concurrentN), atomic.LoadInt32(&calledTimes[3]))
    97  }