github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpoll/trans_server_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package netpoll
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"net"
    23  	"os"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/golang/mock/gomock"
    29  
    30  	"github.com/cloudwego/kitex/internal/mocks"
    31  	mocksremote "github.com/cloudwego/kitex/internal/mocks/remote"
    32  	"github.com/cloudwego/kitex/internal/test"
    33  	"github.com/cloudwego/kitex/pkg/remote"
    34  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    35  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    36  	"github.com/cloudwego/kitex/pkg/utils"
    37  )
    38  
    39  var (
    40  	svrTransHdlr remote.ServerTransHandler
    41  	rwTimeout    = time.Second
    42  	addrStr      = "test addr"
    43  	addr         = utils.NewNetAddr("tcp", addrStr)
    44  	method       = "mock"
    45  	transSvr     *transServer
    46  	svrOpt       *remote.ServerOption
    47  )
    48  
    49  func TestMain(m *testing.M) {
    50  	svcInfo := mocks.ServiceInfo()
    51  	svrOpt = &remote.ServerOption{
    52  		InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
    53  			fromInfo := rpcinfo.EmptyEndpointInfo()
    54  			rpcCfg := rpcinfo.NewRPCConfig()
    55  			mCfg := rpcinfo.AsMutableRPCConfig(rpcCfg)
    56  			mCfg.SetReadWriteTimeout(rwTimeout)
    57  			ink := rpcinfo.NewInvocation("", method)
    58  			rpcStat := rpcinfo.NewRPCStats()
    59  			nri := rpcinfo.NewRPCInfo(fromInfo, nil, ink, rpcCfg, rpcStat)
    60  			rpcinfo.AsMutableEndpointInfo(nri.From()).SetAddress(addr)
    61  			return nri
    62  		},
    63  		Codec: &MockCodec{
    64  			EncodeFunc: nil,
    65  			DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
    66  				msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod)
    67  				return nil
    68  			},
    69  		},
    70  		SvcSearchMap: map[string]*serviceinfo.ServiceInfo{
    71  			remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod):          svcInfo,
    72  			remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo,
    73  			remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod):     svcInfo,
    74  			remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod):    svcInfo,
    75  			mocks.MockMethod:          svcInfo,
    76  			mocks.MockExceptionMethod: svcInfo,
    77  			mocks.MockErrorMethod:     svcInfo,
    78  			mocks.MockOnewayMethod:    svcInfo,
    79  		},
    80  		TargetSvcInfo: svcInfo,
    81  		TracerCtl:     &rpcinfo.TraceController{},
    82  	}
    83  	svrTransHdlr, _ = newSvrTransHandler(svrOpt)
    84  	transSvr = NewTransServerFactory().NewTransServer(svrOpt, svrTransHdlr).(*transServer)
    85  
    86  	os.Exit(m.Run())
    87  }
    88  
    89  // TestCreateListener test trans_server CreateListener success
    90  func TestCreateListener(t *testing.T) {
    91  	// tcp init
    92  	addrStr := "127.0.0.1:9091"
    93  	addr = utils.NewNetAddr("tcp", addrStr)
    94  
    95  	// test
    96  	ln, err := transSvr.CreateListener(addr)
    97  	test.Assert(t, err == nil, err)
    98  	test.Assert(t, ln.Addr().String() == addrStr)
    99  	ln.Close()
   100  
   101  	// uds init
   102  	addrStr = "server.addr"
   103  	addr, err = net.ResolveUnixAddr("unix", addrStr)
   104  	test.Assert(t, err == nil, err)
   105  
   106  	// test
   107  	ln, err = transSvr.CreateListener(addr)
   108  	test.Assert(t, err == nil, err)
   109  	test.Assert(t, ln.Addr().String() == addrStr)
   110  	ln.Close()
   111  }
   112  
   113  // TestBootStrap test trans_server BootstrapServer success
   114  func TestBootStrap(t *testing.T) {
   115  	// tcp init
   116  	addrStr := "127.0.0.1:9092"
   117  	addr = utils.NewNetAddr("tcp", addrStr)
   118  
   119  	// test
   120  	ln, err := transSvr.CreateListener(addr)
   121  	test.Assert(t, err == nil, err)
   122  	test.Assert(t, ln.Addr().String() == addrStr)
   123  
   124  	var wg sync.WaitGroup
   125  	wg.Add(1)
   126  	go func() {
   127  		err = transSvr.BootstrapServer(ln)
   128  		test.Assert(t, err == nil, err)
   129  		wg.Done()
   130  	}()
   131  	time.Sleep(10 * time.Millisecond)
   132  
   133  	transSvr.Shutdown()
   134  	wg.Wait()
   135  }
   136  
   137  // TestOnConnActive test trans_server onConnActive success
   138  func TestConnOnActive(t *testing.T) {
   139  	// 1. prepare mock data
   140  	conn := &MockNetpollConn{
   141  		SetReadTimeoutFunc: func(timeout time.Duration) (e error) {
   142  			return nil
   143  		},
   144  		Conn: mocks.Conn{
   145  			RemoteAddrFunc: func() (r net.Addr) {
   146  				return addr
   147  			},
   148  		},
   149  	}
   150  
   151  	// 2. test
   152  	connCount := 100
   153  	for i := 0; i < connCount; i++ {
   154  		transSvr.onConnActive(conn)
   155  	}
   156  	ctx := context.Background()
   157  
   158  	currConnCount := transSvr.ConnCount()
   159  	test.Assert(t, currConnCount.Value() == connCount)
   160  
   161  	for i := 0; i < connCount; i++ {
   162  		transSvr.onConnInactive(ctx, conn)
   163  	}
   164  
   165  	currConnCount = transSvr.ConnCount()
   166  	test.Assert(t, currConnCount.Value() == 0)
   167  }
   168  
   169  // TestOnConnActivePanic test panic recover when panic happen in OnActive
   170  func TestConnOnActiveAndOnInactivePanic(t *testing.T) {
   171  	ctrl := gomock.NewController(t)
   172  	defer func() {
   173  		ctrl.Finish()
   174  	}()
   175  
   176  	inboundHandler := mocksremote.NewMockInboundHandler(ctrl)
   177  	transPl := remote.NewTransPipeline(svrTransHdlr)
   178  	transPl.AddInboundHandler(inboundHandler)
   179  	transSvrWithPl := NewTransServerFactory().NewTransServer(svrOpt, transPl).(*transServer)
   180  	conn := &MockNetpollConn{
   181  		Conn: mocks.Conn{
   182  			RemoteAddrFunc: func() (r net.Addr) {
   183  				return addr
   184  			},
   185  		},
   186  	}
   187  
   188  	// test1: recover OnActive panic
   189  	inboundHandler.EXPECT().OnActive(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, conn net.Conn) (context.Context, error) {
   190  		panic("mock panic")
   191  	})
   192  	transSvrWithPl.onConnActive(conn)
   193  
   194  	// test2: recover OnInactive panic
   195  	inboundHandler.EXPECT().OnInactive(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, conn net.Conn) (context.Context, error) {
   196  		panic("mock panic")
   197  	})
   198  	transSvrWithPl.onConnInactive(context.Background(), conn)
   199  }
   200  
   201  // TestOnConnRead test trans_server onConnRead success
   202  func TestConnOnRead(t *testing.T) {
   203  	// 1. prepare mock data
   204  	var isClosed bool
   205  	conn := &MockNetpollConn{
   206  		Conn: mocks.Conn{
   207  			RemoteAddrFunc: func() (r net.Addr) {
   208  				return addr
   209  			},
   210  			CloseFunc: func() (e error) {
   211  				isClosed = true
   212  				return nil
   213  			},
   214  		},
   215  	}
   216  	mockErr := errors.New("mock error")
   217  	transSvr.transHdlr = &mocks.MockSvrTransHandler{
   218  		OnReadFunc: func(ctx context.Context, conn net.Conn) error {
   219  			return mockErr
   220  		},
   221  		Opt: transSvr.opt,
   222  	}
   223  
   224  	// 2. test
   225  	err := transSvr.onConnRead(context.Background(), conn)
   226  	test.Assert(t, err == nil, err)
   227  	test.Assert(t, isClosed)
   228  }