github.com/cloudwego/kitex@v0.9.0/server/server_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package server 18 19 import ( 20 "context" 21 "errors" 22 "net" 23 "os" 24 "reflect" 25 "strings" 26 "sync" 27 "sync/atomic" 28 "testing" 29 "time" 30 31 "github.com/bytedance/gopkg/cloud/metainfo" 32 "github.com/cloudwego/localsession" 33 "github.com/cloudwego/localsession/backup" 34 "github.com/golang/mock/gomock" 35 36 "github.com/cloudwego/kitex/internal/mocks" 37 mockslimiter "github.com/cloudwego/kitex/internal/mocks/limiter" 38 internal_server "github.com/cloudwego/kitex/internal/server" 39 "github.com/cloudwego/kitex/internal/test" 40 "github.com/cloudwego/kitex/pkg/endpoint" 41 "github.com/cloudwego/kitex/pkg/limit" 42 "github.com/cloudwego/kitex/pkg/limiter" 43 "github.com/cloudwego/kitex/pkg/registry" 44 "github.com/cloudwego/kitex/pkg/remote" 45 "github.com/cloudwego/kitex/pkg/remote/bound" 46 "github.com/cloudwego/kitex/pkg/remote/trans" 47 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" 48 "github.com/cloudwego/kitex/pkg/rpcinfo" 49 "github.com/cloudwego/kitex/pkg/serviceinfo" 50 "github.com/cloudwego/kitex/pkg/stats" 51 "github.com/cloudwego/kitex/pkg/transmeta" 52 "github.com/cloudwego/kitex/pkg/utils" 53 "github.com/cloudwego/kitex/transport" 54 ) 55 56 var ( 57 svcInfo = mocks.ServiceInfo() 58 svcSearchMap = map[string]*serviceinfo.ServiceInfo{ 59 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, 60 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, 61 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, 62 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, 63 mocks.MockMethod: svcInfo, 64 mocks.MockExceptionMethod: svcInfo, 65 mocks.MockErrorMethod: svcInfo, 66 mocks.MockOnewayMethod: svcInfo, 67 } 68 ) 69 70 func TestServerRun(t *testing.T) { 71 var opts []Option 72 opts = append(opts, WithMetaHandler(noopMetahandler{})) 73 svr := NewServer(opts...) 74 75 time.AfterFunc(time.Millisecond*500, func() { 76 err := svr.Stop() 77 test.Assert(t, err == nil, err) 78 }) 79 80 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 81 test.Assert(t, err == nil) 82 83 var runHook int32 84 var shutdownHook int32 85 RegisterStartHook(func() { 86 atomic.AddInt32(&runHook, 1) 87 }) 88 RegisterShutdownHook(func() { 89 atomic.AddInt32(&shutdownHook, 1) 90 }) 91 err = svr.Run() 92 test.Assert(t, err == nil, err) 93 94 test.Assert(t, atomic.LoadInt32(&runHook) == 1) 95 test.Assert(t, atomic.LoadInt32(&shutdownHook) == 1) 96 } 97 98 func TestReusePortServerRun(t *testing.T) { 99 hostPort := test.GetLocalAddress() 100 addr, _ := net.ResolveTCPAddr("tcp", hostPort) 101 var opts []Option 102 opts = append(opts, WithReusePort(true)) 103 opts = append(opts, WithServiceAddr(addr), WithExitWaitTime(time.Microsecond*10)) 104 105 var wg sync.WaitGroup 106 for i := 0; i < 8; i++ { 107 wg.Add(2) 108 go func() { 109 defer wg.Done() 110 svr := NewServer(opts...) 111 time.AfterFunc(time.Millisecond*100, func() { 112 defer wg.Done() 113 err := svr.Stop() 114 test.Assert(t, err == nil, err) 115 }) 116 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 117 test.Assert(t, err == nil) 118 err = svr.Run() 119 test.Assert(t, err == nil, err) 120 }() 121 } 122 wg.Wait() 123 } 124 125 func TestInitOrResetRPCInfo(t *testing.T) { 126 var opts []Option 127 rwTimeout := time.Millisecond 128 opts = append(opts, WithReadWriteTimeout(rwTimeout)) 129 svr := &server{ 130 opt: internal_server.NewOptions(opts), 131 svcs: newServices(), 132 } 133 svr.init() 134 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 135 test.Assert(t, err == nil) 136 137 remoteAddr := utils.NewNetAddr("tcp", "to") 138 conn := new(mocks.Conn) 139 conn.RemoteAddrFunc = func() (r net.Addr) { 140 return remoteAddr 141 } 142 rpcInfoInitFunc := svr.initOrResetRPCInfoFunc() 143 ri := rpcInfoInitFunc(nil, conn.RemoteAddr()) 144 test.Assert(t, ri != nil) 145 test.Assert(t, ri.From().Address().String() == remoteAddr.String()) 146 test.Assert(t, ri.Config().ReadWriteTimeout() == rwTimeout) 147 148 // modify rpcinfo 149 fi := rpcinfo.AsMutableEndpointInfo(ri.From()) 150 fi.SetServiceName("mock service") 151 fi.SetMethod("mock method") 152 fi.SetTag("key", "value") 153 154 ti := rpcinfo.AsMutableEndpointInfo(ri.To()) 155 ti.SetServiceName("mock service") 156 ti.SetMethod("mock method") 157 ti.SetTag("key", "value") 158 159 if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { 160 setter.SetSeqID(123) 161 setter.SetPackageName("mock package") 162 setter.SetServiceName("mock service") 163 setter.SetMethodName("mock method") 164 } 165 166 mc := rpcinfo.AsMutableRPCConfig(ri.Config()) 167 mc.SetTransportProtocol(transport.TTHeader) 168 mc.SetConnectTimeout(10 * time.Second) 169 mc.SetRPCTimeout(20 * time.Second) 170 mc.SetInteractionMode(rpcinfo.Streaming) 171 mc.SetIOBufferSize(1024) 172 mc.SetReadWriteTimeout(30 * time.Second) 173 174 rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) 175 rpcStats.SetRecvSize(1024) 176 rpcStats.SetSendSize(1024) 177 rpcStats.SetPanicked(errors.New("panic")) 178 rpcStats.SetError(errors.New("err")) 179 rpcStats.SetLevel(stats.LevelDetailed) 180 181 // check setting 182 test.Assert(t, ri.From().ServiceName() == "mock service") 183 test.Assert(t, ri.From().Method() == "mock method") 184 value, exist := ri.From().Tag("key") 185 test.Assert(t, exist && value == "value") 186 187 test.Assert(t, ri.To().ServiceName() == "mock service") 188 test.Assert(t, ri.To().Method() == "mock method") 189 value, exist = ri.To().Tag("key") 190 test.Assert(t, exist && value == "value") 191 192 test.Assert(t, ri.Invocation().SeqID() == 123) 193 test.Assert(t, ri.Invocation().PackageName() == "mock package") 194 test.Assert(t, ri.Invocation().ServiceName() == "mock service") 195 test.Assert(t, ri.Invocation().MethodName() == "mock method") 196 197 test.Assert(t, ri.Config().TransportProtocol() == transport.TTHeader) 198 test.Assert(t, ri.Config().ConnectTimeout() == 10*time.Second) 199 test.Assert(t, ri.Config().RPCTimeout() == 20*time.Second) 200 test.Assert(t, ri.Config().InteractionMode() == rpcinfo.Streaming) 201 test.Assert(t, ri.Config().IOBufferSize() == 1024) 202 test.Assert(t, ri.Config().ReadWriteTimeout() == rwTimeout) 203 204 test.Assert(t, ri.Stats().RecvSize() == 1024) 205 test.Assert(t, ri.Stats().SendSize() == 1024) 206 hasPanicked, _ := ri.Stats().Panicked() 207 test.Assert(t, hasPanicked) 208 test.Assert(t, ri.Stats().Error() != nil) 209 test.Assert(t, ri.Stats().Level() == stats.LevelDetailed) 210 211 // test reset 212 pOld := reflect.ValueOf(ri).Pointer() 213 ri = rpcInfoInitFunc(ri, conn.RemoteAddr()) 214 pNew := reflect.ValueOf(ri).Pointer() 215 test.Assert(t, pOld == pNew, pOld, pNew) 216 test.Assert(t, ri.From().Address().String() == remoteAddr.String()) 217 test.Assert(t, ri.Config().ReadWriteTimeout() == rwTimeout) 218 219 test.Assert(t, ri.From().ServiceName() == "") 220 test.Assert(t, ri.From().Method() == "") 221 _, exist = ri.From().Tag("key") 222 test.Assert(t, !exist) 223 224 test.Assert(t, ri.To().ServiceName() == "") 225 test.Assert(t, ri.To().Method() == "") 226 _, exist = ri.To().Tag("key") 227 test.Assert(t, !exist) 228 229 test.Assert(t, ri.Invocation().SeqID() == 0) 230 test.Assert(t, ri.Invocation().PackageName() == "") 231 test.Assert(t, ri.Invocation().ServiceName() == "") 232 test.Assert(t, ri.Invocation().MethodName() == "") 233 234 test.Assert(t, ri.Config().TransportProtocol() == 0) 235 test.Assert(t, ri.Config().ConnectTimeout() == 50*time.Millisecond) 236 test.Assert(t, ri.Config().RPCTimeout() == 0) 237 test.Assert(t, ri.Config().InteractionMode() == rpcinfo.PingPong) 238 test.Assert(t, ri.Config().IOBufferSize() == 4096) 239 test.Assert(t, ri.Config().ReadWriteTimeout() == rwTimeout) 240 241 test.Assert(t, ri.Stats().RecvSize() == 0) 242 test.Assert(t, ri.Stats().SendSize() == 0) 243 _, panicked := ri.Stats().Panicked() 244 test.Assert(t, panicked == nil) 245 test.Assert(t, ri.Stats().Error() == nil) 246 test.Assert(t, ri.Stats().Level() == 0) 247 248 // test reset after rpcInfoPool is disabled 249 t.Run("reset with pool disabled", func(t *testing.T) { 250 backupState := rpcinfo.PoolEnabled() 251 defer rpcinfo.EnablePool(backupState) 252 rpcinfo.EnablePool(false) 253 254 riNew := rpcInfoInitFunc(ri, conn.RemoteAddr()) 255 pOld, pNew := reflect.ValueOf(ri).Pointer(), reflect.ValueOf(riNew).Pointer() 256 test.Assert(t, pOld != pNew, pOld, pNew) 257 }) 258 } 259 260 func TestServiceRegisterFailed(t *testing.T) { 261 mockRegErr := errors.New("mock register error") 262 var rCount int 263 var drCount int 264 mockRegistry := MockRegistry{ 265 RegisterFunc: func(info *registry.Info) error { 266 rCount++ 267 return mockRegErr 268 }, 269 DeregisterFunc: func(info *registry.Info) error { 270 drCount++ 271 return nil 272 }, 273 } 274 var opts []Option 275 opts = append(opts, WithRegistry(mockRegistry)) 276 svr := NewServer(opts...) 277 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 278 test.Assert(t, err == nil) 279 280 err = svr.Run() 281 test.Assert(t, err != nil) 282 test.Assert(t, strings.Contains(err.Error(), mockRegErr.Error())) 283 test.Assert(t, drCount == 1) 284 } 285 286 func TestServiceDeregisterFailed(t *testing.T) { 287 mockDeregErr := errors.New("mock deregister error") 288 var rCount int 289 var drCount int 290 mockRegistry := MockRegistry{ 291 RegisterFunc: func(info *registry.Info) error { 292 rCount++ 293 return nil 294 }, 295 DeregisterFunc: func(info *registry.Info) error { 296 drCount++ 297 return mockDeregErr 298 }, 299 } 300 var opts []Option 301 opts = append(opts, WithRegistry(mockRegistry)) 302 svr := NewServer(opts...) 303 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 304 test.Assert(t, err == nil) 305 306 time.AfterFunc(2000*time.Millisecond, func() { 307 err := svr.Stop() 308 test.Assert(t, strings.Contains(err.Error(), mockDeregErr.Error())) 309 }) 310 err = svr.Run() 311 test.Assert(t, err == nil, err) 312 test.Assert(t, rCount == 1) 313 } 314 315 func TestServiceRegistryInfo(t *testing.T) { 316 registryInfo := ®istry.Info{ 317 Weight: 100, 318 Tags: map[string]string{"aa": "bb"}, 319 } 320 checkInfo := func(info *registry.Info) { 321 test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) 322 test.Assert(t, info.Weight == registryInfo.Weight, info.Addr) 323 test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) 324 test.Assert(t, len(info.Tags) == len(registryInfo.Tags), info.Tags) 325 test.Assert(t, info.Tags["aa"] == registryInfo.Tags["aa"], info.Tags) 326 } 327 var rCount int 328 var drCount int 329 mockRegistry := MockRegistry{ 330 RegisterFunc: func(info *registry.Info) error { 331 checkInfo(info) 332 rCount++ 333 return nil 334 }, 335 DeregisterFunc: func(info *registry.Info) error { 336 checkInfo(info) 337 drCount++ 338 return nil 339 }, 340 } 341 var opts []Option 342 opts = append(opts, WithRegistry(mockRegistry)) 343 opts = append(opts, WithRegistryInfo(registryInfo)) 344 svr := NewServer(opts...) 345 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 346 test.Assert(t, err == nil) 347 348 time.AfterFunc(2000*time.Millisecond, func() { 349 err := svr.Stop() 350 test.Assert(t, err == nil, err) 351 }) 352 err = svr.Run() 353 test.Assert(t, err == nil, err) 354 test.Assert(t, rCount == 1) 355 test.Assert(t, drCount == 1) 356 } 357 358 func TestServiceRegistryNoInitInfo(t *testing.T) { 359 checkInfo := func(info *registry.Info) { 360 test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) 361 test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) 362 } 363 var rCount int 364 var drCount int 365 mockRegistry := MockRegistry{ 366 RegisterFunc: func(info *registry.Info) error { 367 checkInfo(info) 368 rCount++ 369 return nil 370 }, 371 DeregisterFunc: func(info *registry.Info) error { 372 checkInfo(info) 373 drCount++ 374 return nil 375 }, 376 } 377 var opts []Option 378 opts = append(opts, WithRegistry(mockRegistry)) 379 svr := NewServer(opts...) 380 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 381 test.Assert(t, err == nil) 382 383 time.AfterFunc(2000*time.Millisecond, func() { 384 err := svr.Stop() 385 test.Assert(t, err == nil, err) 386 }) 387 err = svr.Run() 388 test.Assert(t, err == nil, err) 389 test.Assert(t, rCount == 1) 390 test.Assert(t, drCount == 1) 391 } 392 393 // TestServiceRegistryInfoWithNilTags is to check the Tags val. If Tags of RegistryInfo is nil, 394 // the Tags of ServerBasicInfo will be assigned to Tags of RegistryInfo 395 func TestServiceRegistryInfoWithNilTags(t *testing.T) { 396 registryInfo := ®istry.Info{ 397 Weight: 100, 398 } 399 checkInfo := func(info *registry.Info) { 400 test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec) 401 test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr) 402 test.Assert(t, info.Weight == registryInfo.Weight, info.Weight) 403 test.Assert(t, info.Tags["aa"] == "bb", info.Tags) 404 } 405 var rCount int 406 var drCount int 407 mockRegistry := MockRegistry{ 408 RegisterFunc: func(info *registry.Info) error { 409 checkInfo(info) 410 rCount++ 411 return nil 412 }, 413 DeregisterFunc: func(info *registry.Info) error { 414 checkInfo(info) 415 drCount++ 416 return nil 417 }, 418 } 419 var opts []Option 420 opts = append(opts, WithRegistry(mockRegistry)) 421 opts = append(opts, WithRegistryInfo(registryInfo)) 422 opts = append(opts, WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{ 423 Tags: map[string]string{"aa": "bb"}, 424 })) 425 svr := NewServer(opts...) 426 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 427 test.Assert(t, err == nil) 428 429 time.AfterFunc(2000*time.Millisecond, func() { 430 err := svr.Stop() 431 test.Assert(t, err == nil, err) 432 }) 433 err = svr.Run() 434 test.Assert(t, err == nil, err) 435 test.Assert(t, rCount == 1) 436 test.Assert(t, drCount == 1) 437 } 438 439 func TestGRPCServerMultipleServices(t *testing.T) { 440 var opts []Option 441 opts = append(opts, withGRPCTransport()) 442 svr := NewServer(opts...) 443 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 444 test.Assert(t, err == nil) 445 err = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler()) 446 test.Assert(t, err == nil) 447 test.DeepEqual(t, svr.GetServiceInfos()[mocks.MockMethod], mocks.ServiceInfo()) 448 test.DeepEqual(t, svr.GetServiceInfos()[mocks.Mock2Method], mocks.Service2Info()) 449 time.AfterFunc(1000*time.Millisecond, func() { 450 err := svr.Stop() 451 test.Assert(t, err == nil, err) 452 }) 453 err = svr.Run() 454 test.Assert(t, err == nil, err) 455 } 456 457 func TestServerBoundHandler(t *testing.T) { 458 ctrl := gomock.NewController(t) 459 defer ctrl.Finish() 460 461 interval := time.Millisecond * 100 462 cases := []struct { 463 opts []Option 464 wantInbounds []remote.InboundHandler 465 wantOutbounds []remote.OutboundHandler 466 }{ 467 { 468 opts: []Option{ 469 WithLimit(&limit.Option{ 470 MaxConnections: 1000, 471 MaxQPS: 10000, 472 }), 473 WithMetaHandler(noopMetahandler{}), 474 }, 475 wantInbounds: []remote.InboundHandler{ 476 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), 477 bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, false), 478 }, 479 wantOutbounds: []remote.OutboundHandler{ 480 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), 481 }, 482 }, 483 { 484 opts: []Option{ 485 WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)), 486 }, 487 wantInbounds: []remote.InboundHandler{ 488 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 489 bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), &limiter.DummyRateLimiter{}, nil, false), 490 }, 491 wantOutbounds: []remote.OutboundHandler{ 492 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 493 }, 494 }, 495 { 496 opts: []Option{ 497 WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)), 498 }, 499 wantInbounds: []remote.InboundHandler{ 500 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 501 bound.NewServerLimiterHandler(&limiter.DummyConcurrencyLimiter{}, mockslimiter.NewMockRateLimiter(ctrl), nil, true), 502 }, 503 wantOutbounds: []remote.OutboundHandler{ 504 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 505 }, 506 }, 507 { 508 opts: []Option{ 509 WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)), 510 WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)), 511 }, 512 wantInbounds: []remote.InboundHandler{ 513 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 514 bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), 515 }, 516 wantOutbounds: []remote.OutboundHandler{ 517 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 518 }, 519 }, 520 { 521 opts: []Option{ 522 WithLimit(&limit.Option{ 523 MaxConnections: 1000, 524 MaxQPS: 10000, 525 }), 526 WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)), 527 }, 528 wantInbounds: []remote.InboundHandler{ 529 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 530 bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), limiter.NewQPSLimiter(interval, 10000), nil, false), 531 }, 532 wantOutbounds: []remote.OutboundHandler{ 533 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 534 }, 535 }, 536 { 537 opts: []Option{ 538 WithLimit(&limit.Option{ 539 MaxConnections: 1000, 540 MaxQPS: 10000, 541 }), 542 WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)), 543 WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)), 544 }, 545 wantInbounds: []remote.InboundHandler{ 546 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 547 bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), 548 }, 549 wantOutbounds: []remote.OutboundHandler{ 550 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 551 }, 552 }, 553 { 554 opts: []Option{ 555 WithLimit(&limit.Option{ 556 MaxConnections: 1000, 557 MaxQPS: 10000, 558 }), 559 WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)), 560 WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)), 561 WithMuxTransport(), 562 }, 563 wantInbounds: []remote.InboundHandler{ 564 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 565 bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true), 566 }, 567 wantOutbounds: []remote.OutboundHandler{ 568 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 569 }, 570 }, 571 { 572 opts: []Option{ 573 WithLimit(&limit.Option{ 574 MaxConnections: 1000, 575 MaxQPS: 10000, 576 }), 577 WithMuxTransport(), 578 }, 579 wantInbounds: []remote.InboundHandler{ 580 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 581 bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, true), 582 }, 583 wantOutbounds: []remote.OutboundHandler{ 584 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}), 585 }, 586 }, 587 { 588 opts: []Option{ 589 WithMetaHandler(noopMetahandler{}), 590 }, 591 wantInbounds: []remote.InboundHandler{ 592 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), 593 }, 594 wantOutbounds: []remote.OutboundHandler{ 595 bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}), 596 }, 597 }, 598 } 599 for _, tcase := range cases { 600 opts := append(tcase.opts, WithExitWaitTime(time.Millisecond*10)) 601 svr := NewServer(opts...) 602 603 time.AfterFunc(100*time.Millisecond, func() { 604 err := svr.Stop() 605 test.Assert(t, err == nil, err) 606 }) 607 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 608 test.Assert(t, err == nil) 609 err = svr.Run() 610 test.Assert(t, err == nil, err) 611 612 iSvr := svr.(*server) 613 test.Assert(t, inboundDeepEqual(iSvr.opt.RemoteOpt.Inbounds, tcase.wantInbounds)) 614 test.Assert(t, reflect.DeepEqual(iSvr.opt.RemoteOpt.Outbounds, tcase.wantOutbounds)) 615 svr.Stop() 616 } 617 } 618 619 func TestInvokeHandlerWithContextBackup(t *testing.T) { 620 testInvokeHandlerWithSession(t, true, ":8888") 621 os.Setenv(localsession.SESSION_CONFIG_KEY, "true,100,1h") 622 testInvokeHandlerWithSession(t, false, ":8889") 623 } 624 625 func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) { 626 callMethod := "mock" 627 var opts []Option 628 mwExec := false 629 opts = append(opts, WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { 630 return func(ctx context.Context, req, resp interface{}) (err error) { 631 mwExec = true 632 return next(ctx, req, resp) 633 } 634 })) 635 636 k1, v1 := "1", "1" 637 var backupHandler backup.BackupHandler 638 if !fail { 639 opts = append(opts, WithContextBackup(true, true)) 640 backupHandler = func(prev, cur context.Context) (context.Context, bool) { 641 v := prev.Value(k1) 642 if v != nil { 643 cur = context.WithValue(cur, k1, v) 644 return cur, true 645 } 646 return cur, true 647 } 648 } 649 650 opts = append(opts, WithCodec(&mockCodec{})) 651 transHdlrFact := &mockSvrTransHandlerFactory{} 652 exitCh := make(chan bool) 653 var ln net.Listener 654 transSvr := &mocks.MockTransServer{ 655 BootstrapServerFunc: func(net.Listener) error { 656 { // mock server call 657 ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) 658 ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) 659 recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) 660 recvMsg.NewData(callMethod) 661 sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) 662 663 // inject kvs here 664 ctx = metainfo.WithPersistentValue(ctx, "a", "b") 665 ctx = context.WithValue(ctx, k1, v1) 666 667 _, err := transHdlrFact.hdlr.OnMessage(ctx, recvMsg, sendMsg) 668 test.Assert(t, err == nil, err) 669 } 670 <-exitCh 671 return nil 672 }, 673 ShutdownFunc: func() error { 674 if ln != nil { 675 ln.Close() 676 } 677 exitCh <- true 678 return nil 679 }, 680 CreateListenerFunc: func(addr net.Addr) (net.Listener, error) { 681 var err error 682 ln, err = net.Listen("tcp", ad) 683 return ln, err 684 }, 685 } 686 opts = append(opts, WithTransServerFactory(mocks.NewMockTransServerFactory(transSvr))) 687 opts = append(opts, WithTransHandlerFactory(transHdlrFact)) 688 689 svr := NewServer(opts...) 690 time.AfterFunc(100*time.Millisecond, func() { 691 err := svr.Stop() 692 test.Assert(t, err == nil, err) 693 }) 694 serviceHandler := false 695 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MockFuncHandler(func(ctx context.Context, req *mocks.MyRequest) (r *mocks.MyResponse, err error) { 696 serviceHandler = true 697 698 wg := sync.WaitGroup{} 699 wg.Add(1) 700 go func() { 701 defer wg.Done() 702 703 // miss context here 704 ctx := backup.RecoverCtxOnDemands(context.Background(), backupHandler) 705 706 if !fail { 707 b, _ := metainfo.GetPersistentValue(ctx, "a") 708 test.Assert(t, b == "b", "can't get metainfo") 709 test.Assert(t, ctx.Value(k1) == v1, "can't get v1") 710 } else { 711 _, ok := metainfo.GetPersistentValue(ctx, "a") 712 test.Assert(t, !ok, "can get metainfo") 713 test.Assert(t, ctx.Value(k1) != v1, "can get v1") 714 } 715 }() 716 wg.Wait() 717 718 return &mocks.MyResponse{Name: "mock"}, nil 719 })) 720 test.Assert(t, err == nil) 721 err = svr.Run() 722 test.Assert(t, err == nil, err) 723 test.Assert(t, mwExec) 724 test.Assert(t, serviceHandler) 725 } 726 727 func TestInvokeHandlerExec(t *testing.T) { 728 callMethod := "mock" 729 var opts []Option 730 mwExec := false 731 opts = append(opts, WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { 732 return func(ctx context.Context, req, resp interface{}) (err error) { 733 mwExec = true 734 return next(ctx, req, resp) 735 } 736 })) 737 opts = append(opts, WithCodec(&mockCodec{})) 738 transHdlrFact := &mockSvrTransHandlerFactory{} 739 exitCh := make(chan bool) 740 var ln net.Listener 741 transSvr := &mocks.MockTransServer{ 742 BootstrapServerFunc: func(net.Listener) error { 743 { // mock server call 744 ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) 745 ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) 746 recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) 747 recvMsg.NewData(callMethod) 748 sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) 749 750 _, err := transHdlrFact.hdlr.OnMessage(ctx, recvMsg, sendMsg) 751 test.Assert(t, err == nil, err) 752 } 753 <-exitCh 754 return nil 755 }, 756 ShutdownFunc: func() error { 757 if ln != nil { 758 ln.Close() 759 } 760 exitCh <- true 761 return nil 762 }, 763 CreateListenerFunc: func(addr net.Addr) (net.Listener, error) { 764 var err error 765 ln, err = net.Listen("tcp", ":8888") 766 return ln, err 767 }, 768 } 769 opts = append(opts, WithTransServerFactory(mocks.NewMockTransServerFactory(transSvr))) 770 opts = append(opts, WithTransHandlerFactory(transHdlrFact)) 771 772 svr := NewServer(opts...) 773 time.AfterFunc(100*time.Millisecond, func() { 774 err := svr.Stop() 775 test.Assert(t, err == nil, err) 776 }) 777 serviceHandler := false 778 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MockFuncHandler(func(ctx context.Context, req *mocks.MyRequest) (r *mocks.MyResponse, err error) { 779 serviceHandler = true 780 return &mocks.MyResponse{Name: "mock"}, nil 781 })) 782 test.Assert(t, err == nil) 783 err = svr.Run() 784 test.Assert(t, err == nil, err) 785 test.Assert(t, mwExec) 786 test.Assert(t, serviceHandler) 787 } 788 789 func TestInvokeHandlerPanic(t *testing.T) { 790 callMethod := "mock" 791 var opts []Option 792 mwExec := false 793 opts = append(opts, WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { 794 return func(ctx context.Context, req, resp interface{}) (err error) { 795 mwExec = true 796 return next(ctx, req, resp) 797 } 798 })) 799 opts = append(opts, WithCodec(&mockCodec{})) 800 transHdlrFact := &mockSvrTransHandlerFactory{} 801 exitCh := make(chan bool) 802 var ln net.Listener 803 transSvr := &mocks.MockTransServer{ 804 BootstrapServerFunc: func(net.Listener) error { 805 { 806 // mock server call 807 ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) 808 ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) 809 recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) 810 recvMsg.NewData(callMethod) 811 sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) 812 813 _, err := transHdlrFact.hdlr.OnMessage(ctx, recvMsg, sendMsg) 814 test.Assert(t, strings.Contains(err.Error(), "happened in biz handler")) 815 } 816 <-exitCh 817 return nil 818 }, 819 ShutdownFunc: func() error { 820 ln.Close() 821 exitCh <- true 822 return nil 823 }, 824 CreateListenerFunc: func(addr net.Addr) (net.Listener, error) { 825 var err error 826 ln, err = net.Listen("tcp", ":8888") 827 return ln, err 828 }, 829 } 830 opts = append(opts, WithTransServerFactory(mocks.NewMockTransServerFactory(transSvr))) 831 opts = append(opts, WithTransHandlerFactory(transHdlrFact)) 832 833 svr := NewServer(opts...) 834 time.AfterFunc(100*time.Millisecond, func() { 835 err := svr.Stop() 836 test.Assert(t, err == nil, err) 837 }) 838 serviceHandler := false 839 err := svr.RegisterService(mocks.ServiceInfo(), mocks.MockFuncHandler(func(ctx context.Context, req *mocks.MyRequest) (r *mocks.MyResponse, err error) { 840 serviceHandler = true 841 panic("test") 842 })) 843 test.Assert(t, err == nil) 844 err = svr.Run() 845 test.Assert(t, err == nil, err) 846 test.Assert(t, mwExec) 847 test.Assert(t, serviceHandler) 848 } 849 850 func TestRegisterService(t *testing.T) { 851 svr := NewServer() 852 time.AfterFunc(time.Second, func() { 853 err := svr.Stop() 854 test.Assert(t, err == nil, err) 855 }) 856 857 svr.Run() 858 859 test.PanicAt(t, func() { 860 _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 861 }, func(err interface{}) bool { 862 if errMsg, ok := err.(string); ok { 863 return strings.Contains(errMsg, "server is running") 864 } 865 return true 866 }) 867 svr.Stop() 868 869 svr = NewServer() 870 time.AfterFunc(time.Second, func() { 871 err := svr.Stop() 872 test.Assert(t, err == nil, err) 873 }) 874 875 test.PanicAt(t, func() { 876 _ = svr.RegisterService(nil, mocks.MyServiceHandler()) 877 }, func(err interface{}) bool { 878 if errMsg, ok := err.(string); ok { 879 return strings.Contains(errMsg, "svcInfo is nil") 880 } 881 return true 882 }) 883 884 test.PanicAt(t, func() { 885 _ = svr.RegisterService(mocks.ServiceInfo(), nil) 886 }, func(err interface{}) bool { 887 if errMsg, ok := err.(string); ok { 888 return strings.Contains(errMsg, "handler is nil") 889 } 890 return true 891 }) 892 893 test.PanicAt(t, func() { 894 _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler(), WithFallbackService()) 895 _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 896 }, func(err interface{}) bool { 897 if errMsg, ok := err.(string); ok { 898 return strings.Contains(errMsg, "Service[MockService] is already defined") 899 } 900 return true 901 }) 902 903 test.PanicAt(t, func() { 904 _ = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler(), WithFallbackService()) 905 }, func(err interface{}) bool { 906 if errMsg, ok := err.(string); ok { 907 return strings.Contains(errMsg, "multiple fallback services cannot be registered") 908 } 909 return true 910 }) 911 svr.Stop() 912 913 svr = NewServer() 914 time.AfterFunc(time.Second, func() { 915 err := svr.Stop() 916 test.Assert(t, err == nil, err) 917 }) 918 919 _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 920 _ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler()) 921 err := svr.Run() 922 test.Assert(t, err != nil) 923 test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified") 924 svr.Stop() 925 } 926 927 type noopMetahandler struct{} 928 929 func (noopMetahandler) WriteMeta(ctx context.Context, msg remote.Message) (context.Context, error) { 930 return ctx, nil 931 } 932 933 func (noopMetahandler) ReadMeta(ctx context.Context, msg remote.Message) (context.Context, error) { 934 return ctx, nil 935 } 936 func (noopMetahandler) OnConnectStream(ctx context.Context) (context.Context, error) { return ctx, nil } 937 func (noopMetahandler) OnReadStream(ctx context.Context) (context.Context, error) { return ctx, nil } 938 939 type mockSvrTransHandlerFactory struct { 940 hdlr remote.ServerTransHandler 941 } 942 943 func (f *mockSvrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { 944 f.hdlr, _ = trans.NewDefaultSvrTransHandler(opt, &mockExtension{}) 945 return f.hdlr, nil 946 } 947 948 type mockExtension struct{} 949 950 func (m mockExtension) SetReadTimeout(ctx context.Context, conn net.Conn, cfg rpcinfo.RPCConfig, role remote.RPCRole) { 951 } 952 953 func (m mockExtension) NewWriteByteBuffer(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 954 return remote.NewWriterBuffer(0) 955 } 956 957 func (m mockExtension) NewReadByteBuffer(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 958 return remote.NewReaderBuffer(nil) 959 } 960 961 func (m mockExtension) ReleaseBuffer(buffer remote.ByteBuffer, err error) error { 962 return nil 963 } 964 965 func (m mockExtension) IsTimeoutErr(err error) bool { 966 return false 967 } 968 969 func (m mockExtension) IsRemoteClosedErr(err error) bool { 970 return false 971 } 972 973 type mockCodec struct { 974 EncodeFunc func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error 975 DecodeFunc func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error 976 } 977 978 func (m *mockCodec) Name() string { 979 return "Mock" 980 } 981 982 func (m *mockCodec) Encode(ctx context.Context, msg remote.Message, out remote.ByteBuffer) (err error) { 983 if m.EncodeFunc != nil { 984 return m.EncodeFunc(ctx, msg, out) 985 } 986 return 987 } 988 989 func (m *mockCodec) Decode(ctx context.Context, msg remote.Message, in remote.ByteBuffer) (err error) { 990 if m.DecodeFunc != nil { 991 return m.DecodeFunc(ctx, msg, in) 992 } 993 return 994 } 995 996 func TestDuplicatedRegisterInfoPanic(t *testing.T) { 997 svcs := newServices() 998 svcs.addService(mocks.ServiceInfo(), nil, &RegisterOptions{}) 999 s := &server{ 1000 opt: internal_server.NewOptions(nil), 1001 svcs: svcs, 1002 } 1003 s.init() 1004 1005 test.Panic(t, func() { 1006 _ = s.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) 1007 }) 1008 } 1009 1010 func TestRunServiceWithoutSvcInfo(t *testing.T) { 1011 svr := NewServer() 1012 time.AfterFunc(100*time.Millisecond, func() { 1013 _ = svr.Stop() 1014 }) 1015 err := svr.Run() 1016 test.Assert(t, err != nil) 1017 test.Assert(t, strings.Contains(err.Error(), "no service")) 1018 } 1019 1020 func inboundDeepEqual(inbound1, inbound2 []remote.InboundHandler) bool { 1021 if len(inbound1) != len(inbound2) { 1022 return false 1023 } 1024 for i := 0; i < len(inbound1); i++ { 1025 if !bound.DeepEqual(inbound1[i], inbound2[i]) { 1026 return false 1027 } 1028 } 1029 return true 1030 } 1031 1032 func withGRPCTransport() Option { 1033 return Option{F: func(o *internal_server.Options, di *utils.Slice) { 1034 o.RemoteOpt.SvrHandlerFactory = nphttp2.NewSvrTransHandlerFactory() 1035 }} 1036 }