trpc.group/trpc-go/trpc-go@v1.0.3/client/options_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package client_test
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"reflect"
    20  	"runtime"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/stretchr/testify/require"
    25  
    26  	trpc "trpc.group/trpc-go/trpc-go"
    27  	"trpc.group/trpc-go/trpc-go/client"
    28  	"trpc.group/trpc-go/trpc-go/codec"
    29  	"trpc.group/trpc-go/trpc-go/filter"
    30  	"trpc.group/trpc-go/trpc-go/http"
    31  	"trpc.group/trpc-go/trpc-go/naming/registry"
    32  	"trpc.group/trpc-go/trpc-go/pool/connpool"
    33  	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
    34  	"trpc.group/trpc-go/trpc-go/transport"
    35  )
    36  
    37  func TestSelectOptions(t *testing.T) {
    38  	opts := &client.Options{}
    39  
    40  	var selectOptionNum int
    41  	var callOptionNum int
    42  
    43  	// WithCallerServiceName sets service name of the service itself
    44  	o := client.WithCallerServiceName("trpc.test.helloworld1")
    45  	selectOptionNum++
    46  	o(opts)
    47  	require.Equal(t, "trpc.test.helloworld1", opts.CallerServiceName)
    48  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
    49  
    50  	o = client.WithCallerNamespace("Production")
    51  	selectOptionNum++
    52  	o(opts)
    53  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
    54  
    55  	o = client.WithCallerEnvName("test")
    56  	selectOptionNum++
    57  	o(opts)
    58  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
    59  
    60  	o = client.WithCalleeEnvName("test")
    61  	selectOptionNum++
    62  	o(opts)
    63  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
    64  
    65  	o = client.WithCallerSetName("set")
    66  	selectOptionNum++
    67  	o(opts)
    68  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
    69  
    70  	o = client.WithCalleeSetName("set")
    71  	selectOptionNum++
    72  	o(opts)
    73  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
    74  
    75  	o = client.WithCalleeMethod("func")
    76  	o(opts)
    77  	require.Equal(t, "func", opts.CalleeMethod)
    78  
    79  	o = client.WithCallerMetadata("tag", "data")
    80  	selectOptionNum++
    81  	o(opts)
    82  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
    83  
    84  	o = client.WithCalleeMetadata("tag", "data")
    85  	selectOptionNum++
    86  	o(opts)
    87  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
    88  
    89  	o = client.WithPassword("passwd")
    90  	callOptionNum++
    91  	o(opts)
    92  	require.Equal(t, callOptionNum, len(opts.CallOptions))
    93  
    94  	o = client.WithConnectionMode(transport.Connected)
    95  	callOptionNum++
    96  	o(opts)
    97  	require.Equal(t, callOptionNum, len(opts.CallOptions))
    98  
    99  	o = client.WithSendOnly()
   100  	callOptionNum++
   101  	o(opts)
   102  	require.Equal(t, opts.CallType, codec.SendOnly)
   103  	require.Equal(t, callOptionNum, len(opts.CallOptions))
   104  
   105  	o = client.WithTLS("client.cert", "client.key", "ca.pem", "servername")
   106  	callOptionNum++
   107  	o(opts)
   108  	require.Equal(t, callOptionNum, len(opts.CallOptions))
   109  
   110  	o = client.WithDisableConnectionPool()
   111  	callOptionNum++
   112  	o(opts)
   113  	require.Equal(t, callOptionNum, len(opts.CallOptions))
   114  
   115  	o = client.WithDiscoveryName("polaris")
   116  	selectOptionNum++
   117  	o(opts)
   118  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   119  
   120  	o = client.WithServiceRouterName("polaris")
   121  	selectOptionNum++
   122  	o(opts)
   123  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   124  
   125  	o = client.WithBalancerName("polaris")
   126  	selectOptionNum += 2
   127  	o(opts)
   128  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   129  
   130  }
   131  
   132  // TestSelectOptionsOther tests other SelectOptions.
   133  func TestSelectOptionsOther(t *testing.T) {
   134  	opts := &client.Options{}
   135  
   136  	var selectOptionNum int
   137  
   138  	o := client.WithCircuitBreakerName("polaris")
   139  	selectOptionNum++
   140  	o(opts)
   141  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   142  
   143  	client.WithNamespace("development")(opts)
   144  	selectOptionNum++
   145  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   146  
   147  	client.WithEnvKey("env-key")(opts)
   148  	selectOptionNum++
   149  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   150  
   151  	client.WithKey("hash key")(opts)
   152  	selectOptionNum++
   153  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   154  
   155  	client.WithReplicas(100)(opts)
   156  	selectOptionNum++
   157  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   158  
   159  	client.WithDisableServiceRouter()(opts)
   160  	selectOptionNum++
   161  	require.Equal(t, selectOptionNum, len(opts.SelectOptions))
   162  	require.True(t, opts.DisableServiceRouter)
   163  
   164  	client.WithDisableFilter()(opts)
   165  	require.Equal(t, true, opts.DisableFilter)
   166  
   167  }
   168  
   169  func TestOptions(t *testing.T) {
   170  
   171  	opts := &client.Options{}
   172  	transportOpts := &transport.RoundTripOptions{}
   173  
   174  	// WithServiceName sets service name of backend service
   175  	o := client.WithServiceName("trpc.test.helloworld")
   176  	o(opts)
   177  	require.Equal(t, "trpc.test.helloworld", opts.ServiceName)
   178  
   179  	// WithTarget sets target address
   180  	o = client.WithTarget("ip://0.0.0.0:8080")
   181  	o(opts)
   182  	require.Equal(t, "ip://0.0.0.0:8080", opts.Target)
   183  
   184  	// WithNetwork sets network of backend service: tcp or udp, tcp by default
   185  	o = client.WithNetwork("tcp")
   186  	o(opts)
   187  	for _, o := range opts.CallOptions {
   188  		o(transportOpts)
   189  	}
   190  	require.Equal(t, "tcp", transportOpts.Network)
   191  
   192  	// WithTimeout sets timeout of dialing backend, 1s by default.
   193  	o = client.WithTimeout(time.Second)
   194  	o(opts)
   195  	for _, o := range opts.CallOptions {
   196  		o(transportOpts)
   197  	}
   198  	require.Equal(t, time.Second, opts.Timeout)
   199  
   200  	// WithTransport replaces client transport plugin
   201  	o = client.WithTransport(transport.DefaultClientTransport)
   202  	o(opts)
   203  	require.Equal(t, transport.DefaultClientTransport, opts.Transport)
   204  
   205  	// WithStreamTransport replaces client stream transport plugin
   206  	o = client.WithStreamTransport(transport.DefaultClientStreamTransport)
   207  	o(opts)
   208  	require.Equal(t, transport.DefaultClientStreamTransport, opts.StreamTransport)
   209  
   210  	// WithProtocol sets protocol of backend service like trpc
   211  	o = client.WithProtocol("trpc")
   212  	o(opts)
   213  	for _, o := range opts.CallOptions {
   214  		o(transportOpts)
   215  	}
   216  	require.Equal(t, trpc.DefaultClientCodec, opts.Codec)
   217  	require.Equal(t, transport.DefaultClientTransport, opts.Transport)
   218  
   219  	o = client.WithProtocol("http")
   220  	o(opts)
   221  	for _, o := range opts.CallOptions {
   222  		o(transportOpts)
   223  	}
   224  	require.Equal(t, http.DefaultClientCodec, opts.Codec)
   225  	require.Equal(t, http.DefaultClientTransport, opts.Transport)
   226  
   227  	o = client.WithSerializationType(codec.SerializationTypePB)
   228  	o(opts)
   229  	require.Equal(t, codec.SerializationTypePB, opts.SerializationType)
   230  
   231  	o = client.WithCompressType(codec.CompressTypeGzip)
   232  	o(opts)
   233  	require.Equal(t, codec.CompressTypeGzip, opts.CompressType)
   234  
   235  	o = client.WithClientStreamQueueSize(1024)
   236  	o(opts)
   237  	require.Equal(t, 1024, opts.ClientStreamQueueSize)
   238  
   239  	o = client.WithMaxWindowSize(1024)
   240  	o(opts)
   241  	require.Equal(t, uint32(1024), opts.MaxWindowSize)
   242  
   243  	o = client.WithMultiplexed(true)
   244  	o(opts)
   245  	require.Equal(t, true, opts.EnableMultiplexed)
   246  
   247  	// WithFilter appends a client filter to client filter chain.
   248  	o = client.WithFilter(filter.NoopClientFilter)
   249  	o(opts)
   250  	require.Equal(t, 1, len(opts.Filters))
   251  
   252  	// WithFilters appends multiple client filters to client filter chain.
   253  	o = client.WithFilters([]filter.ClientFilter{filter.NoopClientFilter})
   254  	o(opts)
   255  	require.Equal(t, 2, len(opts.Filters))
   256  
   257  	// WithPool sets custom conn pool
   258  	opt := []connpool.Option{
   259  		connpool.WithIdleTimeout(time.Duration(10) * time.Second),
   260  	}
   261  	pool := connpool.NewConnectionPool(opt...)
   262  	o = client.WithPool(pool)
   263  	o(opts)
   264  	for _, o := range opts.CallOptions {
   265  		o(transportOpts)
   266  	}
   267  	require.Equal(t, pool, transportOpts.Pool)
   268  }
   269  
   270  func TestDataOptions(t *testing.T) {
   271  	opts := &client.Options{}
   272  
   273  	// WithReqHead sets req head
   274  	o := client.WithReqHead(nil)
   275  	o(opts)
   276  	require.Equal(t, nil, opts.ReqHead)
   277  
   278  	// WithRspHead sets rsp head
   279  	o = client.WithRspHead(nil)
   280  	o(opts)
   281  	require.Equal(t, nil, opts.RspHead)
   282  
   283  	// WithSelectorNode records selected node
   284  	node := &registry.Node{}
   285  	o = client.WithSelectorNode(node)
   286  	o(opts)
   287  	require.Equal(t, node, opts.Node.Node)
   288  
   289  	o = client.WithCurrentSerializationType(1)
   290  	o(opts)
   291  	require.Equal(t, 1, opts.CurrentSerializationType)
   292  
   293  	o = client.WithCurrentCompressType(1)
   294  	o(opts)
   295  	require.Equal(t, 1, opts.CurrentCompressType)
   296  
   297  	o = client.WithMetaData("key", []byte("value"))
   298  	o(opts)
   299  	require.Equal(t, []byte("value"), opts.MetaData["key"])
   300  }
   301  
   302  // TestWithMultiplexedPool tests WithMultiplexedPool.
   303  func TestWithMultiplexedPool(t *testing.T) {
   304  	opts := &client.Options{}
   305  	roundTripOptions := &transport.RoundTripOptions{}
   306  	m := multiplexed.New(multiplexed.WithConnectNumber(8))
   307  	o := client.WithMultiplexedPool(m)
   308  	o(opts)
   309  	require.True(t, opts.EnableMultiplexed)
   310  	for _, o := range opts.CallOptions {
   311  		o(roundTripOptions)
   312  	}
   313  	require.Equal(t, m, roundTripOptions.Multiplexed)
   314  }
   315  
   316  func TestWithOptionsImmutable(t *testing.T) {
   317  	ctx := context.Background()
   318  	require.False(t, client.IsOptionsImmutable(ctx))
   319  
   320  	newCtx := client.WithOptionsImmutable(ctx)
   321  	require.True(t, client.IsOptionsImmutable(newCtx))
   322  }
   323  
   324  func TestWithLocalAddrOption(t *testing.T) {
   325  	opts := &client.Options{}
   326  	localAddr := "127.0.0.1:8080"
   327  	o := client.WithLocalAddr("127.0.0.1:8080")
   328  	o(opts)
   329  	roundTripOptions := &transport.RoundTripOptions{}
   330  	for _, o := range opts.CallOptions {
   331  		o(roundTripOptions)
   332  	}
   333  	require.Equal(t, roundTripOptions.LocalAddr, localAddr)
   334  }
   335  
   336  func TestWithDialTimeoutOption(t *testing.T) {
   337  	opts := &client.Options{}
   338  	timeout := time.Second
   339  	o := client.WithDialTimeout(timeout)
   340  	o(opts)
   341  	roundTripOptions := &transport.RoundTripOptions{}
   342  	for _, o := range opts.CallOptions {
   343  		o(roundTripOptions)
   344  	}
   345  	require.Equal(t, roundTripOptions.DialTimeout, timeout)
   346  }
   347  
   348  func TestWithNamedFilter(t *testing.T) {
   349  	var (
   350  		filterNames []string
   351  		filters     filter.ClientChain
   352  
   353  		cf = func(
   354  			ctx context.Context,
   355  			req, rsp interface{},
   356  			next filter.ClientHandleFunc) error {
   357  			return next(ctx, req, rsp)
   358  		}
   359  	)
   360  	for i := 0; i < 10; i++ {
   361  		filterNames = append(filterNames, fmt.Sprintf("filter-%d", i))
   362  		filters = append(filters, cf)
   363  	}
   364  
   365  	var os []client.Option
   366  	for i := range filters {
   367  		os = append(os, client.WithNamedFilter(filterNames[i], filters[i]))
   368  	}
   369  
   370  	options := &client.Options{}
   371  	for _, o := range os {
   372  		o(options)
   373  	}
   374  	require.Equal(t, filterNames, options.FilterNames)
   375  	require.Equal(t, len(filters), len(options.Filters))
   376  	for i := range filters {
   377  		require.Equal(
   378  			t,
   379  			runtime.FuncForPC(reflect.ValueOf(filters[i]).Pointer()).Name(),
   380  			runtime.FuncForPC(reflect.ValueOf(options.Filters[i]).Pointer()).Name(),
   381  		)
   382  	}
   383  }