github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpollmux/server_handler_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package netpollmux 18 19 import ( 20 "context" 21 "errors" 22 "net" 23 "sync" 24 "sync/atomic" 25 "testing" 26 "time" 27 28 "github.com/cloudwego/netpoll" 29 30 "github.com/cloudwego/kitex/internal/mocks" 31 "github.com/cloudwego/kitex/internal/test" 32 "github.com/cloudwego/kitex/pkg/remote" 33 "github.com/cloudwego/kitex/pkg/remote/codec" 34 "github.com/cloudwego/kitex/pkg/rpcinfo" 35 "github.com/cloudwego/kitex/pkg/serviceinfo" 36 "github.com/cloudwego/kitex/pkg/utils" 37 ) 38 39 var ( 40 opt *remote.ServerOption 41 rwTimeout = time.Second 42 addrStr = "test addr" 43 addr = utils.NewNetAddr("tcp", addrStr) 44 method = "mock" 45 46 svcInfo = mocks.ServiceInfo() 47 svcSearchMap = map[string]*serviceinfo.ServiceInfo{ 48 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, 49 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, 50 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, 51 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, 52 mocks.MockMethod: svcInfo, 53 mocks.MockExceptionMethod: svcInfo, 54 mocks.MockErrorMethod: svcInfo, 55 mocks.MockOnewayMethod: svcInfo, 56 } 57 ) 58 59 func newTestRpcInfo() rpcinfo.RPCInfo { 60 fromInfo := rpcinfo.EmptyEndpointInfo() 61 rpcCfg := rpcinfo.NewRPCConfig() 62 mCfg := rpcinfo.AsMutableRPCConfig(rpcCfg) 63 mCfg.SetReadWriteTimeout(rwTimeout) 64 ink := rpcinfo.NewInvocation("", method) 65 rpcStat := rpcinfo.NewRPCStats() 66 67 rpcInfo := rpcinfo.NewRPCInfo(fromInfo, nil, ink, rpcCfg, rpcStat) 68 rpcinfo.AsMutableEndpointInfo(rpcInfo.From()).SetAddress(addr) 69 70 return rpcInfo 71 } 72 73 func init() { 74 body := "hello world" 75 rpcInfo := newTestRpcInfo() 76 77 opt = &remote.ServerOption{ 78 InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { 79 return rpcInfo 80 }, 81 Codec: &MockCodec{ 82 EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 83 r := mockHeader(msg.RPCInfo().Invocation().SeqID(), body) 84 _, err := out.WriteBinary(r.Bytes()) 85 return err 86 }, 87 DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 88 in.Skip(3 * codec.Size32) 89 _, err := in.ReadString(len(body)) 90 msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) 91 return err 92 }, 93 }, 94 SvcSearchMap: svcSearchMap, 95 TargetSvcInfo: svcInfo, 96 TracerCtl: &rpcinfo.TraceController{}, 97 ReadWriteTimeout: rwTimeout, 98 } 99 } 100 101 // TestNewTransHandler test new a ServerTransHandler 102 func TestNewTransHandler(t *testing.T) { 103 handler, err := NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{}) 104 test.Assert(t, err == nil, err) 105 test.Assert(t, handler != nil) 106 } 107 108 // TestOnActive test ServerTransHandler OnActive 109 func TestOnActive(t *testing.T) { 110 // 1. prepare mock data 111 var readTimeout time.Duration 112 conn := &MockNetpollConn{ 113 SetReadTimeoutFunc: func(timeout time.Duration) (e error) { 114 readTimeout = timeout 115 return nil 116 }, 117 Conn: mocks.Conn{ 118 RemoteAddrFunc: func() (r net.Addr) { 119 return addr 120 }, 121 }, 122 } 123 124 // 2. test 125 ctx := context.Background() 126 127 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt) 128 129 ctx, err := svrTransHdlr.OnActive(ctx, conn) 130 test.Assert(t, ctx != nil, ctx) 131 test.Assert(t, err == nil, err) 132 muxSvrCon, _ := ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn) 133 test.Assert(t, muxSvrCon != nil) 134 test.Assert(t, readTimeout == rwTimeout, readTimeout, rwTimeout) 135 } 136 137 // TestMuxSvrWrite test ServerTransHandler Write 138 func TestMuxSvrWrite(t *testing.T) { 139 // 1. prepare mock data 140 npconn := &MockNetpollConn{ 141 Conn: mocks.Conn{ 142 RemoteAddrFunc: func() (r net.Addr) { 143 return addr 144 }, 145 }, 146 } 147 pool := &sync.Pool{} 148 muxSvrCon := newMuxSvrConn(npconn, pool) 149 test.Assert(t, muxSvrCon != nil) 150 151 ctx := context.Background() 152 rpcInfo := newTestRpcInfo() 153 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) 154 155 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt) 156 157 msg := &MockMessage{ 158 RPCInfoFunc: func() rpcinfo.RPCInfo { 159 return rpcInfo 160 }, 161 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 162 return &serviceinfo.ServiceInfo{ 163 Methods: map[string]serviceinfo.MethodInfo{ 164 "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), 165 }, 166 } 167 }, 168 } 169 170 // 2. test 171 ri := rpcinfo.GetRPCInfo(ctx) 172 test.Assert(t, ri != nil, ri) 173 174 ctx, err := svrTransHdlr.Write(ctx, muxSvrCon, msg) 175 test.Assert(t, ctx != nil, ctx) 176 test.Assert(t, err == nil, err) 177 } 178 179 // TestMuxSvrOnRead test ServerTransHandler OnRead 180 func TestMuxSvrOnRead(t *testing.T) { 181 var isWriteBufFlushed atomic.Value 182 var isReaderBufReleased atomic.Value 183 var isInvoked atomic.Value 184 185 buf := netpoll.NewLinkBuffer(1024) 186 npconn := &MockNetpollConn{ 187 ReaderFunc: func() (r netpoll.Reader) { 188 isReaderBufReleased.Store(1) 189 return buf 190 }, 191 WriterFunc: func() (r netpoll.Writer) { 192 isWriteBufFlushed.Store(1) 193 return buf 194 }, 195 Conn: mocks.Conn{ 196 RemoteAddrFunc: func() (r net.Addr) { 197 return addr 198 }, 199 }, 200 } 201 202 ctx := context.Background() 203 rpcInfo := newTestRpcInfo() 204 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) 205 206 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt) 207 208 msg := &MockMessage{ 209 RPCInfoFunc: func() rpcinfo.RPCInfo { 210 return rpcInfo 211 }, 212 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 213 return &serviceinfo.ServiceInfo{ 214 Methods: map[string]serviceinfo.MethodInfo{ 215 "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), 216 }, 217 } 218 }, 219 } 220 221 pool := &sync.Pool{} 222 muxSvrCon := newMuxSvrConn(npconn, pool) 223 224 var err error 225 226 ri := rpcinfo.GetRPCInfo(ctx) 227 test.Assert(t, ri != nil, ri) 228 229 ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg) 230 test.Assert(t, ctx != nil, ctx) 231 test.Assert(t, err == nil, err) 232 233 time.Sleep(10 * time.Millisecond) 234 buf.Flush() 235 test.Assert(t, npconn.Reader().Len() > 0, npconn.Reader().Len()) 236 237 ctx, err = svrTransHdlr.OnActive(ctx, muxSvrCon) 238 test.Assert(t, ctx != nil, ctx) 239 test.Assert(t, err == nil, err) 240 muxSvrConFromCtx, _ := ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn) 241 test.Assert(t, muxSvrConFromCtx != nil) 242 243 pl := remote.NewTransPipeline(svrTransHdlr) 244 svrTransHdlr.SetPipeline(pl) 245 246 if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok { 247 setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { 248 isInvoked.Store(1) 249 return nil 250 }) 251 } 252 253 err = svrTransHdlr.OnRead(ctx, npconn) 254 test.Assert(t, err == nil, err) 255 time.Sleep(50 * time.Millisecond) 256 257 test.Assert(t, isReaderBufReleased.Load() == 1) 258 test.Assert(t, isWriteBufFlushed.Load() == 1) 259 test.Assert(t, isInvoked.Load() == 1) 260 } 261 262 // TestPanicAfterMuxSvrOnRead test have panic after read 263 func TestPanicAfterMuxSvrOnRead(t *testing.T) { 264 // 1. prepare mock data 265 var isWriteBufFlushed bool 266 var isReaderBufReleased bool 267 268 buf := netpoll.NewLinkBuffer(1024) 269 conn := &MockNetpollConn{ 270 Conn: mocks.Conn{ 271 RemoteAddrFunc: func() (r net.Addr) { 272 return addr 273 }, 274 CloseFunc: func() (e error) { 275 return nil 276 }, 277 }, 278 ReaderFunc: func() (r netpoll.Reader) { 279 isReaderBufReleased = true 280 return buf 281 }, 282 WriterFunc: func() (r netpoll.Writer) { 283 isWriteBufFlushed = true 284 return buf 285 }, 286 IsActiveFunc: func() (r bool) { 287 return true 288 }, 289 } 290 291 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt) 292 rpcInfo := newTestRpcInfo() 293 294 // pipeline nil panic 295 svrTransHdlr.SetPipeline(nil) 296 297 msg := &MockMessage{ 298 RPCInfoFunc: func() rpcinfo.RPCInfo { 299 return rpcInfo 300 }, 301 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 302 return &serviceinfo.ServiceInfo{ 303 Methods: map[string]serviceinfo.MethodInfo{ 304 "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), 305 }, 306 } 307 }, 308 } 309 310 pool := &sync.Pool{} 311 muxSvrCon := newMuxSvrConn(conn, pool) 312 313 // 2. test 314 var err error 315 ctx := context.Background() 316 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) 317 318 ri := rpcinfo.GetRPCInfo(ctx) 319 test.Assert(t, ri != nil, ri) 320 321 ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg) 322 test.Assert(t, ctx != nil, ctx) 323 test.Assert(t, err == nil, err) 324 325 time.Sleep(5 * time.Millisecond) 326 buf.Flush() 327 test.Assert(t, conn.Reader().Len() > 0, conn.Reader().Len()) 328 329 ctx, err = svrTransHdlr.OnActive(ctx, conn) 330 test.Assert(t, ctx != nil, ctx) 331 test.Assert(t, err == nil, err) 332 333 err = svrTransHdlr.OnRead(ctx, conn) 334 time.Sleep(50 * time.Millisecond) 335 test.Assert(t, err == nil, err) 336 test.Assert(t, isReaderBufReleased) 337 test.Assert(t, isWriteBufFlushed) 338 } 339 340 // TestRecoverAfterOnReadPanic test tryRecover after read panic 341 func TestRecoverAfterOnReadPanic(t *testing.T) { 342 var isWriteBufFlushed bool 343 var isReaderBufReleased bool 344 var isClosed bool 345 buf := netpoll.NewLinkBuffer(1024) 346 347 conn := &MockNetpollConn{ 348 Conn: mocks.Conn{ 349 RemoteAddrFunc: func() (r net.Addr) { 350 return addr 351 }, 352 CloseFunc: func() (e error) { 353 isClosed = true 354 return nil 355 }, 356 }, 357 ReaderFunc: func() (r netpoll.Reader) { 358 isReaderBufReleased = true 359 return buf 360 }, 361 WriterFunc: func() (r netpoll.Writer) { 362 isWriteBufFlushed = true 363 return buf 364 }, 365 IsActiveFunc: func() (r bool) { 366 return true 367 }, 368 } 369 370 rpcInfo := newTestRpcInfo() 371 372 msg := &MockMessage{ 373 RPCInfoFunc: func() rpcinfo.RPCInfo { 374 return rpcInfo 375 }, 376 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 377 return &serviceinfo.ServiceInfo{ 378 Methods: map[string]serviceinfo.MethodInfo{ 379 "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), 380 }, 381 } 382 }, 383 } 384 385 pool := &sync.Pool{} 386 muxSvrCon := newMuxSvrConn(conn, pool) 387 388 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt) 389 390 var err error 391 ctx := context.Background() 392 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) 393 394 ri := rpcinfo.GetRPCInfo(ctx) 395 test.Assert(t, ri != nil, ri) 396 397 ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg) 398 test.Assert(t, ctx != nil, ctx) 399 test.Assert(t, err == nil, err) 400 401 time.Sleep(5 * time.Millisecond) 402 buf.Flush() 403 test.Assert(t, conn.Reader().Len() > 0, conn.Reader().Len()) 404 405 ctx, err = svrTransHdlr.OnActive(ctx, conn) 406 test.Assert(t, ctx != nil, ctx) 407 test.Assert(t, err == nil, err) 408 409 // test recover after panic 410 err = svrTransHdlr.OnRead(ctx, nil) 411 test.Assert(t, err == nil, err) 412 test.Assert(t, isReaderBufReleased) 413 test.Assert(t, isWriteBufFlushed) 414 test.Assert(t, !isClosed) 415 416 // test recover after panic 417 err = svrTransHdlr.OnRead(ctx, &MockNetpollConn{}) 418 test.Assert(t, err == nil, err) 419 test.Assert(t, isReaderBufReleased) 420 test.Assert(t, isWriteBufFlushed) 421 test.Assert(t, !isClosed) 422 } 423 424 // TestOnError test Invoke has err 425 func TestInvokeError(t *testing.T) { 426 var isReaderBufReleased bool 427 var isWriteBufFlushed atomic.Value 428 var invokedErr atomic.Value 429 430 buf := netpoll.NewLinkBuffer(1024) 431 npconn := &MockNetpollConn{ 432 ReaderFunc: func() (r netpoll.Reader) { 433 isReaderBufReleased = true 434 return buf 435 }, 436 WriterFunc: func() (r netpoll.Writer) { 437 isWriteBufFlushed.Store(1) 438 return buf 439 }, 440 Conn: mocks.Conn{ 441 RemoteAddrFunc: func() (r net.Addr) { 442 return addr 443 }, 444 CloseFunc: func() (e error) { 445 return nil 446 }, 447 }, 448 } 449 450 rpcInfo := newTestRpcInfo() 451 452 msg := &MockMessage{ 453 RPCInfoFunc: func() rpcinfo.RPCInfo { 454 return rpcInfo 455 }, 456 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 457 return &serviceinfo.ServiceInfo{ 458 Methods: map[string]serviceinfo.MethodInfo{ 459 "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), 460 }, 461 } 462 }, 463 } 464 465 body := "hello world" 466 opt := &remote.ServerOption{ 467 InitOrResetRPCInfoFunc: func(rpcInfo rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { 468 fromInfo := rpcinfo.EmptyEndpointInfo() 469 rpcCfg := rpcinfo.NewRPCConfig() 470 mCfg := rpcinfo.AsMutableRPCConfig(rpcCfg) 471 mCfg.SetReadWriteTimeout(rwTimeout) 472 ink := rpcinfo.NewInvocation("", method) 473 rpcStat := rpcinfo.NewRPCStats() 474 nri := rpcinfo.NewRPCInfo(fromInfo, nil, ink, rpcCfg, rpcStat) 475 rpcinfo.AsMutableEndpointInfo(nri.From()).SetAddress(addr) 476 return nri 477 }, 478 Codec: &MockCodec{ 479 EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 480 r := mockHeader(msg.RPCInfo().Invocation().SeqID(), body) 481 _, err := out.WriteBinary(r.Bytes()) 482 return err 483 }, 484 DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 485 in.Skip(3 * codec.Size32) 486 _, err := in.ReadString(len(body)) 487 msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) 488 return err 489 }, 490 }, 491 SvcSearchMap: svcSearchMap, 492 TargetSvcInfo: svcInfo, 493 TracerCtl: &rpcinfo.TraceController{}, 494 ReadWriteTimeout: rwTimeout, 495 } 496 497 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt) 498 499 pool := &sync.Pool{ 500 New: func() interface{} { 501 // init rpcinfo 502 ri := opt.InitOrResetRPCInfoFunc(nil, npconn.RemoteAddr()) 503 return ri 504 }, 505 } 506 muxSvrCon := newMuxSvrConn(npconn, pool) 507 508 var err error 509 ctx := context.Background() 510 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) 511 512 ri := rpcinfo.GetRPCInfo(ctx) 513 test.Assert(t, ri != nil, ri) 514 515 ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg) 516 test.Assert(t, ctx != nil, ctx) 517 test.Assert(t, err == nil, err) 518 519 time.Sleep(5 * time.Millisecond) 520 buf.Flush() 521 test.Assert(t, npconn.Reader().Len() > 0, npconn.Reader().Len()) 522 523 ctx, err = svrTransHdlr.OnActive(ctx, muxSvrCon) 524 test.Assert(t, ctx != nil, ctx) 525 test.Assert(t, err == nil, err) 526 muxSvrCon, _ = ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn) 527 test.Assert(t, muxSvrCon != nil) 528 529 pl := remote.NewTransPipeline(svrTransHdlr) 530 svrTransHdlr.SetPipeline(pl) 531 532 if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok { 533 setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { 534 err = errors.New("mock invoke err test") 535 invokedErr.Store(err) 536 return err 537 }) 538 } 539 540 err = svrTransHdlr.OnRead(ctx, npconn) 541 time.Sleep(50 * time.Millisecond) 542 test.Assert(t, err == nil, err) 543 test.Assert(t, isReaderBufReleased) 544 test.Assert(t, invokedErr.Load() != nil) 545 test.Assert(t, isWriteBufFlushed.Load() == 1) 546 } 547 548 // TestOnError test OnError method 549 func TestOnError(t *testing.T) { 550 // 1. prepare mock data 551 buf := netpoll.NewLinkBuffer(1) 552 conn := &MockNetpollConn{ 553 Conn: mocks.Conn{ 554 RemoteAddrFunc: func() (r net.Addr) { 555 return addr 556 }, 557 CloseFunc: func() (e error) { 558 return nil 559 }, 560 }, 561 ReaderFunc: func() (r netpoll.Reader) { 562 return buf 563 }, 564 WriterFunc: func() (r netpoll.Writer) { 565 return buf 566 }, 567 } 568 569 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt) 570 571 // 2. test 572 ctx := context.Background() 573 rpcInfo := newTestRpcInfo() 574 575 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) 576 svrTransHdlr.OnError(ctx, errors.New("test mock err"), conn) 577 svrTransHdlr.OnError(ctx, netpoll.ErrConnClosed, conn) 578 } 579 580 // TestInvokeNoMethod test invoke no method 581 func TestInvokeNoMethod(t *testing.T) { 582 var isWriteBufFlushed atomic.Value 583 var isReaderBufReleased bool 584 var isInvoked bool 585 586 buf := netpoll.NewLinkBuffer(1024) 587 npconn := &MockNetpollConn{ 588 ReaderFunc: func() (r netpoll.Reader) { 589 isReaderBufReleased = true 590 return buf 591 }, 592 WriterFunc: func() (r netpoll.Writer) { 593 isWriteBufFlushed.Store(1) 594 return buf 595 }, 596 Conn: mocks.Conn{ 597 RemoteAddrFunc: func() (r net.Addr) { 598 return addr 599 }, 600 CloseFunc: func() (e error) { 601 return nil 602 }, 603 }, 604 } 605 606 rpcInfo := newTestRpcInfo() 607 608 msg := &MockMessage{ 609 RPCInfoFunc: func() rpcinfo.RPCInfo { 610 return rpcInfo 611 }, 612 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 613 return &serviceinfo.ServiceInfo{ 614 Methods: map[string]serviceinfo.MethodInfo{ 615 "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), 616 }, 617 } 618 }, 619 } 620 621 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt) 622 623 pool := &sync.Pool{} 624 muxSvrCon := newMuxSvrConn(npconn, pool) 625 626 var err error 627 ctx := context.Background() 628 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) 629 630 ri := rpcinfo.GetRPCInfo(ctx) 631 test.Assert(t, ri != nil, ri) 632 633 ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg) 634 test.Assert(t, ctx != nil, ctx) 635 test.Assert(t, err == nil, err) 636 637 time.Sleep(5 * time.Millisecond) 638 buf.Flush() 639 test.Assert(t, npconn.Reader().Len() > 0, npconn.Reader().Len()) 640 641 ctx, err = svrTransHdlr.OnActive(ctx, muxSvrCon) 642 test.Assert(t, ctx != nil, ctx) 643 test.Assert(t, err == nil, err) 644 muxSvrCon, _ = ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn) 645 test.Assert(t, muxSvrCon != nil) 646 647 pl := remote.NewTransPipeline(svrTransHdlr) 648 svrTransHdlr.SetPipeline(pl) 649 650 svcInfo = opt.TargetSvcInfo 651 delete(svcInfo.Methods, method) 652 653 if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok { 654 setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { 655 isInvoked = true 656 return nil 657 }) 658 } 659 660 err = svrTransHdlr.OnRead(ctx, npconn) 661 time.Sleep(50 * time.Millisecond) 662 test.Assert(t, err == nil, err) 663 test.Assert(t, isReaderBufReleased) 664 test.Assert(t, isWriteBufFlushed.Load() == 1) 665 test.Assert(t, !isInvoked) 666 } 667 668 // TestMuxSvcOnReadHeartbeat test SvrTransHandler OnRead to process heartbeat 669 func TestMuxSvrOnReadHeartbeat(t *testing.T) { 670 var isWriteBufFlushed atomic.Value 671 var isReaderBufReleased atomic.Value 672 var isInvoked atomic.Value 673 674 buf := netpoll.NewLinkBuffer(1024) 675 npconn := &MockNetpollConn{ 676 ReaderFunc: func() (r netpoll.Reader) { 677 isReaderBufReleased.Store(1) 678 return buf 679 }, 680 WriterFunc: func() (r netpoll.Writer) { 681 isWriteBufFlushed.Store(1) 682 return buf 683 }, 684 Conn: mocks.Conn{ 685 RemoteAddrFunc: func() (r net.Addr) { 686 return addr 687 }, 688 }, 689 } 690 691 var heartbeatFlag bool 692 body := "non-heartbeat process" 693 ctx := context.Background() 694 rpcInfo := newTestRpcInfo() 695 ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo) 696 697 // use newOpt cause we need to add heartbeat logic to EncodeFunc and DecodeFunc 698 newOpt := &remote.ServerOption{ 699 InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { 700 return rpcInfo 701 }, 702 Codec: &MockCodec{ 703 EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 704 if heartbeatFlag { 705 if msg.MessageType() != remote.Heartbeat { 706 return errors.New("response is not of MessageType Heartbeat") 707 } 708 return nil 709 } 710 r := mockHeader(msg.RPCInfo().Invocation().SeqID(), body) 711 _, err := out.WriteBinary(r.Bytes()) 712 return err 713 }, 714 DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 715 if heartbeatFlag { 716 msg.SetMessageType(remote.Heartbeat) 717 return nil 718 } 719 in.Skip(3 * codec.Size32) 720 _, err := in.ReadString(len(body)) 721 return err 722 }, 723 }, 724 SvcSearchMap: svcSearchMap, 725 TargetSvcInfo: svcInfo, 726 TracerCtl: &rpcinfo.TraceController{}, 727 ReadWriteTimeout: rwTimeout, 728 } 729 svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(newOpt) 730 731 msg := &MockMessage{ 732 RPCInfoFunc: func() rpcinfo.RPCInfo { 733 return rpcInfo 734 }, 735 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 736 return &serviceinfo.ServiceInfo{ 737 Methods: map[string]serviceinfo.MethodInfo{ 738 "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), 739 }, 740 } 741 }, 742 } 743 744 pool := &sync.Pool{} 745 muxSvrCon := newMuxSvrConn(npconn, pool) 746 747 var err error 748 749 ri := rpcinfo.GetRPCInfo(ctx) 750 test.Assert(t, ri != nil, ri) 751 752 ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg) 753 test.Assert(t, ctx != nil, ctx) 754 test.Assert(t, err == nil, err) 755 756 time.Sleep(10 * time.Millisecond) 757 buf.Flush() 758 test.Assert(t, npconn.Reader().Len() > 0, npconn.Reader().Len()) 759 760 ctx, err = svrTransHdlr.OnActive(ctx, muxSvrCon) 761 test.Assert(t, ctx != nil, ctx) 762 test.Assert(t, err == nil, err) 763 muxSvrConFromCtx, _ := ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn) 764 test.Assert(t, muxSvrConFromCtx != nil) 765 766 pl := remote.NewTransPipeline(svrTransHdlr) 767 svrTransHdlr.SetPipeline(pl) 768 769 if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok { 770 setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { 771 isInvoked.Store(1) 772 return nil 773 }) 774 } 775 776 // start the heartbeat processing 777 heartbeatFlag = true 778 err = svrTransHdlr.OnRead(ctx, npconn) 779 test.Assert(t, err == nil, err) 780 time.Sleep(50 * time.Millisecond) 781 782 test.Assert(t, isReaderBufReleased.Load() == 1) 783 test.Assert(t, isWriteBufFlushed.Load() == 1) 784 // InvokeHandleFunc has not been invoked 785 test.Assert(t, isInvoked.Load() == nil) 786 }