github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpollmux/server_handler_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package netpollmux
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"net"
    23  	"sync"
    24  	"sync/atomic"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/cloudwego/netpoll"
    29  
    30  	"github.com/cloudwego/kitex/internal/mocks"
    31  	"github.com/cloudwego/kitex/internal/test"
    32  	"github.com/cloudwego/kitex/pkg/remote"
    33  	"github.com/cloudwego/kitex/pkg/remote/codec"
    34  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    35  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    36  	"github.com/cloudwego/kitex/pkg/utils"
    37  )
    38  
    39  var (
    40  	opt       *remote.ServerOption
    41  	rwTimeout = time.Second
    42  	addrStr   = "test addr"
    43  	addr      = utils.NewNetAddr("tcp", addrStr)
    44  	method    = "mock"
    45  
    46  	svcInfo      = mocks.ServiceInfo()
    47  	svcSearchMap = map[string]*serviceinfo.ServiceInfo{
    48  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod):          svcInfo,
    49  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo,
    50  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod):     svcInfo,
    51  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod):    svcInfo,
    52  		mocks.MockMethod:          svcInfo,
    53  		mocks.MockExceptionMethod: svcInfo,
    54  		mocks.MockErrorMethod:     svcInfo,
    55  		mocks.MockOnewayMethod:    svcInfo,
    56  	}
    57  )
    58  
    59  func newTestRpcInfo() rpcinfo.RPCInfo {
    60  	fromInfo := rpcinfo.EmptyEndpointInfo()
    61  	rpcCfg := rpcinfo.NewRPCConfig()
    62  	mCfg := rpcinfo.AsMutableRPCConfig(rpcCfg)
    63  	mCfg.SetReadWriteTimeout(rwTimeout)
    64  	ink := rpcinfo.NewInvocation("", method)
    65  	rpcStat := rpcinfo.NewRPCStats()
    66  
    67  	rpcInfo := rpcinfo.NewRPCInfo(fromInfo, nil, ink, rpcCfg, rpcStat)
    68  	rpcinfo.AsMutableEndpointInfo(rpcInfo.From()).SetAddress(addr)
    69  
    70  	return rpcInfo
    71  }
    72  
    73  func init() {
    74  	body := "hello world"
    75  	rpcInfo := newTestRpcInfo()
    76  
    77  	opt = &remote.ServerOption{
    78  		InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
    79  			return rpcInfo
    80  		},
    81  		Codec: &MockCodec{
    82  			EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
    83  				r := mockHeader(msg.RPCInfo().Invocation().SeqID(), body)
    84  				_, err := out.WriteBinary(r.Bytes())
    85  				return err
    86  			},
    87  			DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
    88  				in.Skip(3 * codec.Size32)
    89  				_, err := in.ReadString(len(body))
    90  				msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod)
    91  				return err
    92  			},
    93  		},
    94  		SvcSearchMap:     svcSearchMap,
    95  		TargetSvcInfo:    svcInfo,
    96  		TracerCtl:        &rpcinfo.TraceController{},
    97  		ReadWriteTimeout: rwTimeout,
    98  	}
    99  }
   100  
   101  // TestNewTransHandler test new a ServerTransHandler
   102  func TestNewTransHandler(t *testing.T) {
   103  	handler, err := NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{})
   104  	test.Assert(t, err == nil, err)
   105  	test.Assert(t, handler != nil)
   106  }
   107  
   108  // TestOnActive test ServerTransHandler OnActive
   109  func TestOnActive(t *testing.T) {
   110  	// 1. prepare mock data
   111  	var readTimeout time.Duration
   112  	conn := &MockNetpollConn{
   113  		SetReadTimeoutFunc: func(timeout time.Duration) (e error) {
   114  			readTimeout = timeout
   115  			return nil
   116  		},
   117  		Conn: mocks.Conn{
   118  			RemoteAddrFunc: func() (r net.Addr) {
   119  				return addr
   120  			},
   121  		},
   122  	}
   123  
   124  	// 2. test
   125  	ctx := context.Background()
   126  
   127  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt)
   128  
   129  	ctx, err := svrTransHdlr.OnActive(ctx, conn)
   130  	test.Assert(t, ctx != nil, ctx)
   131  	test.Assert(t, err == nil, err)
   132  	muxSvrCon, _ := ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn)
   133  	test.Assert(t, muxSvrCon != nil)
   134  	test.Assert(t, readTimeout == rwTimeout, readTimeout, rwTimeout)
   135  }
   136  
   137  // TestMuxSvrWrite test ServerTransHandler Write
   138  func TestMuxSvrWrite(t *testing.T) {
   139  	// 1. prepare mock data
   140  	npconn := &MockNetpollConn{
   141  		Conn: mocks.Conn{
   142  			RemoteAddrFunc: func() (r net.Addr) {
   143  				return addr
   144  			},
   145  		},
   146  	}
   147  	pool := &sync.Pool{}
   148  	muxSvrCon := newMuxSvrConn(npconn, pool)
   149  	test.Assert(t, muxSvrCon != nil)
   150  
   151  	ctx := context.Background()
   152  	rpcInfo := newTestRpcInfo()
   153  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo)
   154  
   155  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt)
   156  
   157  	msg := &MockMessage{
   158  		RPCInfoFunc: func() rpcinfo.RPCInfo {
   159  			return rpcInfo
   160  		},
   161  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
   162  			return &serviceinfo.ServiceInfo{
   163  				Methods: map[string]serviceinfo.MethodInfo{
   164  					"method": serviceinfo.NewMethodInfo(nil, nil, nil, false),
   165  				},
   166  			}
   167  		},
   168  	}
   169  
   170  	// 2. test
   171  	ri := rpcinfo.GetRPCInfo(ctx)
   172  	test.Assert(t, ri != nil, ri)
   173  
   174  	ctx, err := svrTransHdlr.Write(ctx, muxSvrCon, msg)
   175  	test.Assert(t, ctx != nil, ctx)
   176  	test.Assert(t, err == nil, err)
   177  }
   178  
   179  // TestMuxSvrOnRead test ServerTransHandler OnRead
   180  func TestMuxSvrOnRead(t *testing.T) {
   181  	var isWriteBufFlushed atomic.Value
   182  	var isReaderBufReleased atomic.Value
   183  	var isInvoked atomic.Value
   184  
   185  	buf := netpoll.NewLinkBuffer(1024)
   186  	npconn := &MockNetpollConn{
   187  		ReaderFunc: func() (r netpoll.Reader) {
   188  			isReaderBufReleased.Store(1)
   189  			return buf
   190  		},
   191  		WriterFunc: func() (r netpoll.Writer) {
   192  			isWriteBufFlushed.Store(1)
   193  			return buf
   194  		},
   195  		Conn: mocks.Conn{
   196  			RemoteAddrFunc: func() (r net.Addr) {
   197  				return addr
   198  			},
   199  		},
   200  	}
   201  
   202  	ctx := context.Background()
   203  	rpcInfo := newTestRpcInfo()
   204  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo)
   205  
   206  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt)
   207  
   208  	msg := &MockMessage{
   209  		RPCInfoFunc: func() rpcinfo.RPCInfo {
   210  			return rpcInfo
   211  		},
   212  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
   213  			return &serviceinfo.ServiceInfo{
   214  				Methods: map[string]serviceinfo.MethodInfo{
   215  					"method": serviceinfo.NewMethodInfo(nil, nil, nil, false),
   216  				},
   217  			}
   218  		},
   219  	}
   220  
   221  	pool := &sync.Pool{}
   222  	muxSvrCon := newMuxSvrConn(npconn, pool)
   223  
   224  	var err error
   225  
   226  	ri := rpcinfo.GetRPCInfo(ctx)
   227  	test.Assert(t, ri != nil, ri)
   228  
   229  	ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg)
   230  	test.Assert(t, ctx != nil, ctx)
   231  	test.Assert(t, err == nil, err)
   232  
   233  	time.Sleep(10 * time.Millisecond)
   234  	buf.Flush()
   235  	test.Assert(t, npconn.Reader().Len() > 0, npconn.Reader().Len())
   236  
   237  	ctx, err = svrTransHdlr.OnActive(ctx, muxSvrCon)
   238  	test.Assert(t, ctx != nil, ctx)
   239  	test.Assert(t, err == nil, err)
   240  	muxSvrConFromCtx, _ := ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn)
   241  	test.Assert(t, muxSvrConFromCtx != nil)
   242  
   243  	pl := remote.NewTransPipeline(svrTransHdlr)
   244  	svrTransHdlr.SetPipeline(pl)
   245  
   246  	if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok {
   247  		setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
   248  			isInvoked.Store(1)
   249  			return nil
   250  		})
   251  	}
   252  
   253  	err = svrTransHdlr.OnRead(ctx, npconn)
   254  	test.Assert(t, err == nil, err)
   255  	time.Sleep(50 * time.Millisecond)
   256  
   257  	test.Assert(t, isReaderBufReleased.Load() == 1)
   258  	test.Assert(t, isWriteBufFlushed.Load() == 1)
   259  	test.Assert(t, isInvoked.Load() == 1)
   260  }
   261  
   262  // TestPanicAfterMuxSvrOnRead test have panic after read
   263  func TestPanicAfterMuxSvrOnRead(t *testing.T) {
   264  	// 1. prepare mock data
   265  	var isWriteBufFlushed bool
   266  	var isReaderBufReleased bool
   267  
   268  	buf := netpoll.NewLinkBuffer(1024)
   269  	conn := &MockNetpollConn{
   270  		Conn: mocks.Conn{
   271  			RemoteAddrFunc: func() (r net.Addr) {
   272  				return addr
   273  			},
   274  			CloseFunc: func() (e error) {
   275  				return nil
   276  			},
   277  		},
   278  		ReaderFunc: func() (r netpoll.Reader) {
   279  			isReaderBufReleased = true
   280  			return buf
   281  		},
   282  		WriterFunc: func() (r netpoll.Writer) {
   283  			isWriteBufFlushed = true
   284  			return buf
   285  		},
   286  		IsActiveFunc: func() (r bool) {
   287  			return true
   288  		},
   289  	}
   290  
   291  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt)
   292  	rpcInfo := newTestRpcInfo()
   293  
   294  	// pipeline nil panic
   295  	svrTransHdlr.SetPipeline(nil)
   296  
   297  	msg := &MockMessage{
   298  		RPCInfoFunc: func() rpcinfo.RPCInfo {
   299  			return rpcInfo
   300  		},
   301  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
   302  			return &serviceinfo.ServiceInfo{
   303  				Methods: map[string]serviceinfo.MethodInfo{
   304  					"method": serviceinfo.NewMethodInfo(nil, nil, nil, false),
   305  				},
   306  			}
   307  		},
   308  	}
   309  
   310  	pool := &sync.Pool{}
   311  	muxSvrCon := newMuxSvrConn(conn, pool)
   312  
   313  	// 2. test
   314  	var err error
   315  	ctx := context.Background()
   316  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo)
   317  
   318  	ri := rpcinfo.GetRPCInfo(ctx)
   319  	test.Assert(t, ri != nil, ri)
   320  
   321  	ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg)
   322  	test.Assert(t, ctx != nil, ctx)
   323  	test.Assert(t, err == nil, err)
   324  
   325  	time.Sleep(5 * time.Millisecond)
   326  	buf.Flush()
   327  	test.Assert(t, conn.Reader().Len() > 0, conn.Reader().Len())
   328  
   329  	ctx, err = svrTransHdlr.OnActive(ctx, conn)
   330  	test.Assert(t, ctx != nil, ctx)
   331  	test.Assert(t, err == nil, err)
   332  
   333  	err = svrTransHdlr.OnRead(ctx, conn)
   334  	time.Sleep(50 * time.Millisecond)
   335  	test.Assert(t, err == nil, err)
   336  	test.Assert(t, isReaderBufReleased)
   337  	test.Assert(t, isWriteBufFlushed)
   338  }
   339  
   340  // TestRecoverAfterOnReadPanic test tryRecover after read panic
   341  func TestRecoverAfterOnReadPanic(t *testing.T) {
   342  	var isWriteBufFlushed bool
   343  	var isReaderBufReleased bool
   344  	var isClosed bool
   345  	buf := netpoll.NewLinkBuffer(1024)
   346  
   347  	conn := &MockNetpollConn{
   348  		Conn: mocks.Conn{
   349  			RemoteAddrFunc: func() (r net.Addr) {
   350  				return addr
   351  			},
   352  			CloseFunc: func() (e error) {
   353  				isClosed = true
   354  				return nil
   355  			},
   356  		},
   357  		ReaderFunc: func() (r netpoll.Reader) {
   358  			isReaderBufReleased = true
   359  			return buf
   360  		},
   361  		WriterFunc: func() (r netpoll.Writer) {
   362  			isWriteBufFlushed = true
   363  			return buf
   364  		},
   365  		IsActiveFunc: func() (r bool) {
   366  			return true
   367  		},
   368  	}
   369  
   370  	rpcInfo := newTestRpcInfo()
   371  
   372  	msg := &MockMessage{
   373  		RPCInfoFunc: func() rpcinfo.RPCInfo {
   374  			return rpcInfo
   375  		},
   376  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
   377  			return &serviceinfo.ServiceInfo{
   378  				Methods: map[string]serviceinfo.MethodInfo{
   379  					"method": serviceinfo.NewMethodInfo(nil, nil, nil, false),
   380  				},
   381  			}
   382  		},
   383  	}
   384  
   385  	pool := &sync.Pool{}
   386  	muxSvrCon := newMuxSvrConn(conn, pool)
   387  
   388  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt)
   389  
   390  	var err error
   391  	ctx := context.Background()
   392  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo)
   393  
   394  	ri := rpcinfo.GetRPCInfo(ctx)
   395  	test.Assert(t, ri != nil, ri)
   396  
   397  	ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg)
   398  	test.Assert(t, ctx != nil, ctx)
   399  	test.Assert(t, err == nil, err)
   400  
   401  	time.Sleep(5 * time.Millisecond)
   402  	buf.Flush()
   403  	test.Assert(t, conn.Reader().Len() > 0, conn.Reader().Len())
   404  
   405  	ctx, err = svrTransHdlr.OnActive(ctx, conn)
   406  	test.Assert(t, ctx != nil, ctx)
   407  	test.Assert(t, err == nil, err)
   408  
   409  	// test recover after panic
   410  	err = svrTransHdlr.OnRead(ctx, nil)
   411  	test.Assert(t, err == nil, err)
   412  	test.Assert(t, isReaderBufReleased)
   413  	test.Assert(t, isWriteBufFlushed)
   414  	test.Assert(t, !isClosed)
   415  
   416  	// test recover after panic
   417  	err = svrTransHdlr.OnRead(ctx, &MockNetpollConn{})
   418  	test.Assert(t, err == nil, err)
   419  	test.Assert(t, isReaderBufReleased)
   420  	test.Assert(t, isWriteBufFlushed)
   421  	test.Assert(t, !isClosed)
   422  }
   423  
   424  // TestOnError test Invoke has err
   425  func TestInvokeError(t *testing.T) {
   426  	var isReaderBufReleased bool
   427  	var isWriteBufFlushed atomic.Value
   428  	var invokedErr atomic.Value
   429  
   430  	buf := netpoll.NewLinkBuffer(1024)
   431  	npconn := &MockNetpollConn{
   432  		ReaderFunc: func() (r netpoll.Reader) {
   433  			isReaderBufReleased = true
   434  			return buf
   435  		},
   436  		WriterFunc: func() (r netpoll.Writer) {
   437  			isWriteBufFlushed.Store(1)
   438  			return buf
   439  		},
   440  		Conn: mocks.Conn{
   441  			RemoteAddrFunc: func() (r net.Addr) {
   442  				return addr
   443  			},
   444  			CloseFunc: func() (e error) {
   445  				return nil
   446  			},
   447  		},
   448  	}
   449  
   450  	rpcInfo := newTestRpcInfo()
   451  
   452  	msg := &MockMessage{
   453  		RPCInfoFunc: func() rpcinfo.RPCInfo {
   454  			return rpcInfo
   455  		},
   456  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
   457  			return &serviceinfo.ServiceInfo{
   458  				Methods: map[string]serviceinfo.MethodInfo{
   459  					"method": serviceinfo.NewMethodInfo(nil, nil, nil, false),
   460  				},
   461  			}
   462  		},
   463  	}
   464  
   465  	body := "hello world"
   466  	opt := &remote.ServerOption{
   467  		InitOrResetRPCInfoFunc: func(rpcInfo rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
   468  			fromInfo := rpcinfo.EmptyEndpointInfo()
   469  			rpcCfg := rpcinfo.NewRPCConfig()
   470  			mCfg := rpcinfo.AsMutableRPCConfig(rpcCfg)
   471  			mCfg.SetReadWriteTimeout(rwTimeout)
   472  			ink := rpcinfo.NewInvocation("", method)
   473  			rpcStat := rpcinfo.NewRPCStats()
   474  			nri := rpcinfo.NewRPCInfo(fromInfo, nil, ink, rpcCfg, rpcStat)
   475  			rpcinfo.AsMutableEndpointInfo(nri.From()).SetAddress(addr)
   476  			return nri
   477  		},
   478  		Codec: &MockCodec{
   479  			EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
   480  				r := mockHeader(msg.RPCInfo().Invocation().SeqID(), body)
   481  				_, err := out.WriteBinary(r.Bytes())
   482  				return err
   483  			},
   484  			DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
   485  				in.Skip(3 * codec.Size32)
   486  				_, err := in.ReadString(len(body))
   487  				msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod)
   488  				return err
   489  			},
   490  		},
   491  		SvcSearchMap:     svcSearchMap,
   492  		TargetSvcInfo:    svcInfo,
   493  		TracerCtl:        &rpcinfo.TraceController{},
   494  		ReadWriteTimeout: rwTimeout,
   495  	}
   496  
   497  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt)
   498  
   499  	pool := &sync.Pool{
   500  		New: func() interface{} {
   501  			// init rpcinfo
   502  			ri := opt.InitOrResetRPCInfoFunc(nil, npconn.RemoteAddr())
   503  			return ri
   504  		},
   505  	}
   506  	muxSvrCon := newMuxSvrConn(npconn, pool)
   507  
   508  	var err error
   509  	ctx := context.Background()
   510  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo)
   511  
   512  	ri := rpcinfo.GetRPCInfo(ctx)
   513  	test.Assert(t, ri != nil, ri)
   514  
   515  	ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg)
   516  	test.Assert(t, ctx != nil, ctx)
   517  	test.Assert(t, err == nil, err)
   518  
   519  	time.Sleep(5 * time.Millisecond)
   520  	buf.Flush()
   521  	test.Assert(t, npconn.Reader().Len() > 0, npconn.Reader().Len())
   522  
   523  	ctx, err = svrTransHdlr.OnActive(ctx, muxSvrCon)
   524  	test.Assert(t, ctx != nil, ctx)
   525  	test.Assert(t, err == nil, err)
   526  	muxSvrCon, _ = ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn)
   527  	test.Assert(t, muxSvrCon != nil)
   528  
   529  	pl := remote.NewTransPipeline(svrTransHdlr)
   530  	svrTransHdlr.SetPipeline(pl)
   531  
   532  	if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok {
   533  		setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
   534  			err = errors.New("mock invoke err test")
   535  			invokedErr.Store(err)
   536  			return err
   537  		})
   538  	}
   539  
   540  	err = svrTransHdlr.OnRead(ctx, npconn)
   541  	time.Sleep(50 * time.Millisecond)
   542  	test.Assert(t, err == nil, err)
   543  	test.Assert(t, isReaderBufReleased)
   544  	test.Assert(t, invokedErr.Load() != nil)
   545  	test.Assert(t, isWriteBufFlushed.Load() == 1)
   546  }
   547  
   548  // TestOnError test OnError method
   549  func TestOnError(t *testing.T) {
   550  	// 1. prepare mock data
   551  	buf := netpoll.NewLinkBuffer(1)
   552  	conn := &MockNetpollConn{
   553  		Conn: mocks.Conn{
   554  			RemoteAddrFunc: func() (r net.Addr) {
   555  				return addr
   556  			},
   557  			CloseFunc: func() (e error) {
   558  				return nil
   559  			},
   560  		},
   561  		ReaderFunc: func() (r netpoll.Reader) {
   562  			return buf
   563  		},
   564  		WriterFunc: func() (r netpoll.Writer) {
   565  			return buf
   566  		},
   567  	}
   568  
   569  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt)
   570  
   571  	// 2. test
   572  	ctx := context.Background()
   573  	rpcInfo := newTestRpcInfo()
   574  
   575  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo)
   576  	svrTransHdlr.OnError(ctx, errors.New("test mock err"), conn)
   577  	svrTransHdlr.OnError(ctx, netpoll.ErrConnClosed, conn)
   578  }
   579  
   580  // TestInvokeNoMethod test invoke no method
   581  func TestInvokeNoMethod(t *testing.T) {
   582  	var isWriteBufFlushed atomic.Value
   583  	var isReaderBufReleased bool
   584  	var isInvoked bool
   585  
   586  	buf := netpoll.NewLinkBuffer(1024)
   587  	npconn := &MockNetpollConn{
   588  		ReaderFunc: func() (r netpoll.Reader) {
   589  			isReaderBufReleased = true
   590  			return buf
   591  		},
   592  		WriterFunc: func() (r netpoll.Writer) {
   593  			isWriteBufFlushed.Store(1)
   594  			return buf
   595  		},
   596  		Conn: mocks.Conn{
   597  			RemoteAddrFunc: func() (r net.Addr) {
   598  				return addr
   599  			},
   600  			CloseFunc: func() (e error) {
   601  				return nil
   602  			},
   603  		},
   604  	}
   605  
   606  	rpcInfo := newTestRpcInfo()
   607  
   608  	msg := &MockMessage{
   609  		RPCInfoFunc: func() rpcinfo.RPCInfo {
   610  			return rpcInfo
   611  		},
   612  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
   613  			return &serviceinfo.ServiceInfo{
   614  				Methods: map[string]serviceinfo.MethodInfo{
   615  					"method": serviceinfo.NewMethodInfo(nil, nil, nil, false),
   616  				},
   617  			}
   618  		},
   619  	}
   620  
   621  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(opt)
   622  
   623  	pool := &sync.Pool{}
   624  	muxSvrCon := newMuxSvrConn(npconn, pool)
   625  
   626  	var err error
   627  	ctx := context.Background()
   628  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo)
   629  
   630  	ri := rpcinfo.GetRPCInfo(ctx)
   631  	test.Assert(t, ri != nil, ri)
   632  
   633  	ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg)
   634  	test.Assert(t, ctx != nil, ctx)
   635  	test.Assert(t, err == nil, err)
   636  
   637  	time.Sleep(5 * time.Millisecond)
   638  	buf.Flush()
   639  	test.Assert(t, npconn.Reader().Len() > 0, npconn.Reader().Len())
   640  
   641  	ctx, err = svrTransHdlr.OnActive(ctx, muxSvrCon)
   642  	test.Assert(t, ctx != nil, ctx)
   643  	test.Assert(t, err == nil, err)
   644  	muxSvrCon, _ = ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn)
   645  	test.Assert(t, muxSvrCon != nil)
   646  
   647  	pl := remote.NewTransPipeline(svrTransHdlr)
   648  	svrTransHdlr.SetPipeline(pl)
   649  
   650  	svcInfo = opt.TargetSvcInfo
   651  	delete(svcInfo.Methods, method)
   652  
   653  	if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok {
   654  		setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
   655  			isInvoked = true
   656  			return nil
   657  		})
   658  	}
   659  
   660  	err = svrTransHdlr.OnRead(ctx, npconn)
   661  	time.Sleep(50 * time.Millisecond)
   662  	test.Assert(t, err == nil, err)
   663  	test.Assert(t, isReaderBufReleased)
   664  	test.Assert(t, isWriteBufFlushed.Load() == 1)
   665  	test.Assert(t, !isInvoked)
   666  }
   667  
   668  // TestMuxSvcOnReadHeartbeat test SvrTransHandler OnRead to process heartbeat
   669  func TestMuxSvrOnReadHeartbeat(t *testing.T) {
   670  	var isWriteBufFlushed atomic.Value
   671  	var isReaderBufReleased atomic.Value
   672  	var isInvoked atomic.Value
   673  
   674  	buf := netpoll.NewLinkBuffer(1024)
   675  	npconn := &MockNetpollConn{
   676  		ReaderFunc: func() (r netpoll.Reader) {
   677  			isReaderBufReleased.Store(1)
   678  			return buf
   679  		},
   680  		WriterFunc: func() (r netpoll.Writer) {
   681  			isWriteBufFlushed.Store(1)
   682  			return buf
   683  		},
   684  		Conn: mocks.Conn{
   685  			RemoteAddrFunc: func() (r net.Addr) {
   686  				return addr
   687  			},
   688  		},
   689  	}
   690  
   691  	var heartbeatFlag bool
   692  	body := "non-heartbeat process"
   693  	ctx := context.Background()
   694  	rpcInfo := newTestRpcInfo()
   695  	ctx = rpcinfo.NewCtxWithRPCInfo(ctx, rpcInfo)
   696  
   697  	// use newOpt cause we need to add heartbeat logic to EncodeFunc and DecodeFunc
   698  	newOpt := &remote.ServerOption{
   699  		InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
   700  			return rpcInfo
   701  		},
   702  		Codec: &MockCodec{
   703  			EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
   704  				if heartbeatFlag {
   705  					if msg.MessageType() != remote.Heartbeat {
   706  						return errors.New("response is not of MessageType Heartbeat")
   707  					}
   708  					return nil
   709  				}
   710  				r := mockHeader(msg.RPCInfo().Invocation().SeqID(), body)
   711  				_, err := out.WriteBinary(r.Bytes())
   712  				return err
   713  			},
   714  			DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
   715  				if heartbeatFlag {
   716  					msg.SetMessageType(remote.Heartbeat)
   717  					return nil
   718  				}
   719  				in.Skip(3 * codec.Size32)
   720  				_, err := in.ReadString(len(body))
   721  				return err
   722  			},
   723  		},
   724  		SvcSearchMap:     svcSearchMap,
   725  		TargetSvcInfo:    svcInfo,
   726  		TracerCtl:        &rpcinfo.TraceController{},
   727  		ReadWriteTimeout: rwTimeout,
   728  	}
   729  	svrTransHdlr, _ := NewSvrTransHandlerFactory().NewTransHandler(newOpt)
   730  
   731  	msg := &MockMessage{
   732  		RPCInfoFunc: func() rpcinfo.RPCInfo {
   733  			return rpcInfo
   734  		},
   735  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
   736  			return &serviceinfo.ServiceInfo{
   737  				Methods: map[string]serviceinfo.MethodInfo{
   738  					"method": serviceinfo.NewMethodInfo(nil, nil, nil, false),
   739  				},
   740  			}
   741  		},
   742  	}
   743  
   744  	pool := &sync.Pool{}
   745  	muxSvrCon := newMuxSvrConn(npconn, pool)
   746  
   747  	var err error
   748  
   749  	ri := rpcinfo.GetRPCInfo(ctx)
   750  	test.Assert(t, ri != nil, ri)
   751  
   752  	ctx, err = svrTransHdlr.Write(ctx, muxSvrCon, msg)
   753  	test.Assert(t, ctx != nil, ctx)
   754  	test.Assert(t, err == nil, err)
   755  
   756  	time.Sleep(10 * time.Millisecond)
   757  	buf.Flush()
   758  	test.Assert(t, npconn.Reader().Len() > 0, npconn.Reader().Len())
   759  
   760  	ctx, err = svrTransHdlr.OnActive(ctx, muxSvrCon)
   761  	test.Assert(t, ctx != nil, ctx)
   762  	test.Assert(t, err == nil, err)
   763  	muxSvrConFromCtx, _ := ctx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn)
   764  	test.Assert(t, muxSvrConFromCtx != nil)
   765  
   766  	pl := remote.NewTransPipeline(svrTransHdlr)
   767  	svrTransHdlr.SetPipeline(pl)
   768  
   769  	if setter, ok := svrTransHdlr.(remote.InvokeHandleFuncSetter); ok {
   770  		setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
   771  			isInvoked.Store(1)
   772  			return nil
   773  		})
   774  	}
   775  
   776  	// start the heartbeat processing
   777  	heartbeatFlag = true
   778  	err = svrTransHdlr.OnRead(ctx, npconn)
   779  	test.Assert(t, err == nil, err)
   780  	time.Sleep(50 * time.Millisecond)
   781  
   782  	test.Assert(t, isReaderBufReleased.Load() == 1)
   783  	test.Assert(t, isWriteBufFlushed.Load() == 1)
   784  	// InvokeHandleFunc has not been invoked
   785  	test.Assert(t, isInvoked.Load() == nil)
   786  }