trpc.group/trpc-go/trpc-go@v1.0.3/client/options_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package client_test 15 16 import ( 17 "context" 18 "fmt" 19 "reflect" 20 "runtime" 21 "testing" 22 "time" 23 24 "github.com/stretchr/testify/require" 25 26 trpc "trpc.group/trpc-go/trpc-go" 27 "trpc.group/trpc-go/trpc-go/client" 28 "trpc.group/trpc-go/trpc-go/codec" 29 "trpc.group/trpc-go/trpc-go/filter" 30 "trpc.group/trpc-go/trpc-go/http" 31 "trpc.group/trpc-go/trpc-go/naming/registry" 32 "trpc.group/trpc-go/trpc-go/pool/connpool" 33 "trpc.group/trpc-go/trpc-go/pool/multiplexed" 34 "trpc.group/trpc-go/trpc-go/transport" 35 ) 36 37 func TestSelectOptions(t *testing.T) { 38 opts := &client.Options{} 39 40 var selectOptionNum int 41 var callOptionNum int 42 43 // WithCallerServiceName sets service name of the service itself 44 o := client.WithCallerServiceName("trpc.test.helloworld1") 45 selectOptionNum++ 46 o(opts) 47 require.Equal(t, "trpc.test.helloworld1", opts.CallerServiceName) 48 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 49 50 o = client.WithCallerNamespace("Production") 51 selectOptionNum++ 52 o(opts) 53 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 54 55 o = client.WithCallerEnvName("test") 56 selectOptionNum++ 57 o(opts) 58 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 59 60 o = client.WithCalleeEnvName("test") 61 selectOptionNum++ 62 o(opts) 63 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 64 65 o = client.WithCallerSetName("set") 66 selectOptionNum++ 67 o(opts) 68 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 69 70 o = client.WithCalleeSetName("set") 71 selectOptionNum++ 72 o(opts) 73 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 74 75 o = client.WithCalleeMethod("func") 76 o(opts) 77 require.Equal(t, "func", opts.CalleeMethod) 78 79 o = client.WithCallerMetadata("tag", "data") 80 selectOptionNum++ 81 o(opts) 82 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 83 84 o = client.WithCalleeMetadata("tag", "data") 85 selectOptionNum++ 86 o(opts) 87 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 88 89 o = client.WithPassword("passwd") 90 callOptionNum++ 91 o(opts) 92 require.Equal(t, callOptionNum, len(opts.CallOptions)) 93 94 o = client.WithConnectionMode(transport.Connected) 95 callOptionNum++ 96 o(opts) 97 require.Equal(t, callOptionNum, len(opts.CallOptions)) 98 99 o = client.WithSendOnly() 100 callOptionNum++ 101 o(opts) 102 require.Equal(t, opts.CallType, codec.SendOnly) 103 require.Equal(t, callOptionNum, len(opts.CallOptions)) 104 105 o = client.WithTLS("client.cert", "client.key", "ca.pem", "servername") 106 callOptionNum++ 107 o(opts) 108 require.Equal(t, callOptionNum, len(opts.CallOptions)) 109 110 o = client.WithDisableConnectionPool() 111 callOptionNum++ 112 o(opts) 113 require.Equal(t, callOptionNum, len(opts.CallOptions)) 114 115 o = client.WithDiscoveryName("polaris") 116 selectOptionNum++ 117 o(opts) 118 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 119 120 o = client.WithServiceRouterName("polaris") 121 selectOptionNum++ 122 o(opts) 123 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 124 125 o = client.WithBalancerName("polaris") 126 selectOptionNum += 2 127 o(opts) 128 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 129 130 } 131 132 // TestSelectOptionsOther tests other SelectOptions. 133 func TestSelectOptionsOther(t *testing.T) { 134 opts := &client.Options{} 135 136 var selectOptionNum int 137 138 o := client.WithCircuitBreakerName("polaris") 139 selectOptionNum++ 140 o(opts) 141 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 142 143 client.WithNamespace("development")(opts) 144 selectOptionNum++ 145 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 146 147 client.WithEnvKey("env-key")(opts) 148 selectOptionNum++ 149 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 150 151 client.WithKey("hash key")(opts) 152 selectOptionNum++ 153 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 154 155 client.WithReplicas(100)(opts) 156 selectOptionNum++ 157 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 158 159 client.WithDisableServiceRouter()(opts) 160 selectOptionNum++ 161 require.Equal(t, selectOptionNum, len(opts.SelectOptions)) 162 require.True(t, opts.DisableServiceRouter) 163 164 client.WithDisableFilter()(opts) 165 require.Equal(t, true, opts.DisableFilter) 166 167 } 168 169 func TestOptions(t *testing.T) { 170 171 opts := &client.Options{} 172 transportOpts := &transport.RoundTripOptions{} 173 174 // WithServiceName sets service name of backend service 175 o := client.WithServiceName("trpc.test.helloworld") 176 o(opts) 177 require.Equal(t, "trpc.test.helloworld", opts.ServiceName) 178 179 // WithTarget sets target address 180 o = client.WithTarget("ip://0.0.0.0:8080") 181 o(opts) 182 require.Equal(t, "ip://0.0.0.0:8080", opts.Target) 183 184 // WithNetwork sets network of backend service: tcp or udp, tcp by default 185 o = client.WithNetwork("tcp") 186 o(opts) 187 for _, o := range opts.CallOptions { 188 o(transportOpts) 189 } 190 require.Equal(t, "tcp", transportOpts.Network) 191 192 // WithTimeout sets timeout of dialing backend, 1s by default. 193 o = client.WithTimeout(time.Second) 194 o(opts) 195 for _, o := range opts.CallOptions { 196 o(transportOpts) 197 } 198 require.Equal(t, time.Second, opts.Timeout) 199 200 // WithTransport replaces client transport plugin 201 o = client.WithTransport(transport.DefaultClientTransport) 202 o(opts) 203 require.Equal(t, transport.DefaultClientTransport, opts.Transport) 204 205 // WithStreamTransport replaces client stream transport plugin 206 o = client.WithStreamTransport(transport.DefaultClientStreamTransport) 207 o(opts) 208 require.Equal(t, transport.DefaultClientStreamTransport, opts.StreamTransport) 209 210 // WithProtocol sets protocol of backend service like trpc 211 o = client.WithProtocol("trpc") 212 o(opts) 213 for _, o := range opts.CallOptions { 214 o(transportOpts) 215 } 216 require.Equal(t, trpc.DefaultClientCodec, opts.Codec) 217 require.Equal(t, transport.DefaultClientTransport, opts.Transport) 218 219 o = client.WithProtocol("http") 220 o(opts) 221 for _, o := range opts.CallOptions { 222 o(transportOpts) 223 } 224 require.Equal(t, http.DefaultClientCodec, opts.Codec) 225 require.Equal(t, http.DefaultClientTransport, opts.Transport) 226 227 o = client.WithSerializationType(codec.SerializationTypePB) 228 o(opts) 229 require.Equal(t, codec.SerializationTypePB, opts.SerializationType) 230 231 o = client.WithCompressType(codec.CompressTypeGzip) 232 o(opts) 233 require.Equal(t, codec.CompressTypeGzip, opts.CompressType) 234 235 o = client.WithClientStreamQueueSize(1024) 236 o(opts) 237 require.Equal(t, 1024, opts.ClientStreamQueueSize) 238 239 o = client.WithMaxWindowSize(1024) 240 o(opts) 241 require.Equal(t, uint32(1024), opts.MaxWindowSize) 242 243 o = client.WithMultiplexed(true) 244 o(opts) 245 require.Equal(t, true, opts.EnableMultiplexed) 246 247 // WithFilter appends a client filter to client filter chain. 248 o = client.WithFilter(filter.NoopClientFilter) 249 o(opts) 250 require.Equal(t, 1, len(opts.Filters)) 251 252 // WithFilters appends multiple client filters to client filter chain. 253 o = client.WithFilters([]filter.ClientFilter{filter.NoopClientFilter}) 254 o(opts) 255 require.Equal(t, 2, len(opts.Filters)) 256 257 // WithPool sets custom conn pool 258 opt := []connpool.Option{ 259 connpool.WithIdleTimeout(time.Duration(10) * time.Second), 260 } 261 pool := connpool.NewConnectionPool(opt...) 262 o = client.WithPool(pool) 263 o(opts) 264 for _, o := range opts.CallOptions { 265 o(transportOpts) 266 } 267 require.Equal(t, pool, transportOpts.Pool) 268 } 269 270 func TestDataOptions(t *testing.T) { 271 opts := &client.Options{} 272 273 // WithReqHead sets req head 274 o := client.WithReqHead(nil) 275 o(opts) 276 require.Equal(t, nil, opts.ReqHead) 277 278 // WithRspHead sets rsp head 279 o = client.WithRspHead(nil) 280 o(opts) 281 require.Equal(t, nil, opts.RspHead) 282 283 // WithSelectorNode records selected node 284 node := ®istry.Node{} 285 o = client.WithSelectorNode(node) 286 o(opts) 287 require.Equal(t, node, opts.Node.Node) 288 289 o = client.WithCurrentSerializationType(1) 290 o(opts) 291 require.Equal(t, 1, opts.CurrentSerializationType) 292 293 o = client.WithCurrentCompressType(1) 294 o(opts) 295 require.Equal(t, 1, opts.CurrentCompressType) 296 297 o = client.WithMetaData("key", []byte("value")) 298 o(opts) 299 require.Equal(t, []byte("value"), opts.MetaData["key"]) 300 } 301 302 // TestWithMultiplexedPool tests WithMultiplexedPool. 303 func TestWithMultiplexedPool(t *testing.T) { 304 opts := &client.Options{} 305 roundTripOptions := &transport.RoundTripOptions{} 306 m := multiplexed.New(multiplexed.WithConnectNumber(8)) 307 o := client.WithMultiplexedPool(m) 308 o(opts) 309 require.True(t, opts.EnableMultiplexed) 310 for _, o := range opts.CallOptions { 311 o(roundTripOptions) 312 } 313 require.Equal(t, m, roundTripOptions.Multiplexed) 314 } 315 316 func TestWithOptionsImmutable(t *testing.T) { 317 ctx := context.Background() 318 require.False(t, client.IsOptionsImmutable(ctx)) 319 320 newCtx := client.WithOptionsImmutable(ctx) 321 require.True(t, client.IsOptionsImmutable(newCtx)) 322 } 323 324 func TestWithLocalAddrOption(t *testing.T) { 325 opts := &client.Options{} 326 localAddr := "127.0.0.1:8080" 327 o := client.WithLocalAddr("127.0.0.1:8080") 328 o(opts) 329 roundTripOptions := &transport.RoundTripOptions{} 330 for _, o := range opts.CallOptions { 331 o(roundTripOptions) 332 } 333 require.Equal(t, roundTripOptions.LocalAddr, localAddr) 334 } 335 336 func TestWithDialTimeoutOption(t *testing.T) { 337 opts := &client.Options{} 338 timeout := time.Second 339 o := client.WithDialTimeout(timeout) 340 o(opts) 341 roundTripOptions := &transport.RoundTripOptions{} 342 for _, o := range opts.CallOptions { 343 o(roundTripOptions) 344 } 345 require.Equal(t, roundTripOptions.DialTimeout, timeout) 346 } 347 348 func TestWithNamedFilter(t *testing.T) { 349 var ( 350 filterNames []string 351 filters filter.ClientChain 352 353 cf = func( 354 ctx context.Context, 355 req, rsp interface{}, 356 next filter.ClientHandleFunc) error { 357 return next(ctx, req, rsp) 358 } 359 ) 360 for i := 0; i < 10; i++ { 361 filterNames = append(filterNames, fmt.Sprintf("filter-%d", i)) 362 filters = append(filters, cf) 363 } 364 365 var os []client.Option 366 for i := range filters { 367 os = append(os, client.WithNamedFilter(filterNames[i], filters[i])) 368 } 369 370 options := &client.Options{} 371 for _, o := range os { 372 o(options) 373 } 374 require.Equal(t, filterNames, options.FilterNames) 375 require.Equal(t, len(filters), len(options.Filters)) 376 for i := range filters { 377 require.Equal( 378 t, 379 runtime.FuncForPC(reflect.ValueOf(filters[i]).Pointer()).Name(), 380 runtime.FuncForPC(reflect.ValueOf(options.Filters[i]).Pointer()).Name(), 381 ) 382 } 383 }