github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpoll/trans_server_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package netpoll 18 19 import ( 20 "context" 21 "errors" 22 "net" 23 "os" 24 "sync" 25 "testing" 26 "time" 27 28 "github.com/golang/mock/gomock" 29 30 "github.com/cloudwego/kitex/internal/mocks" 31 mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" 32 "github.com/cloudwego/kitex/internal/test" 33 "github.com/cloudwego/kitex/pkg/remote" 34 "github.com/cloudwego/kitex/pkg/rpcinfo" 35 "github.com/cloudwego/kitex/pkg/serviceinfo" 36 "github.com/cloudwego/kitex/pkg/utils" 37 ) 38 39 var ( 40 svrTransHdlr remote.ServerTransHandler 41 rwTimeout = time.Second 42 addrStr = "test addr" 43 addr = utils.NewNetAddr("tcp", addrStr) 44 method = "mock" 45 transSvr *transServer 46 svrOpt *remote.ServerOption 47 ) 48 49 func TestMain(m *testing.M) { 50 svcInfo := mocks.ServiceInfo() 51 svrOpt = &remote.ServerOption{ 52 InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { 53 fromInfo := rpcinfo.EmptyEndpointInfo() 54 rpcCfg := rpcinfo.NewRPCConfig() 55 mCfg := rpcinfo.AsMutableRPCConfig(rpcCfg) 56 mCfg.SetReadWriteTimeout(rwTimeout) 57 ink := rpcinfo.NewInvocation("", method) 58 rpcStat := rpcinfo.NewRPCStats() 59 nri := rpcinfo.NewRPCInfo(fromInfo, nil, ink, rpcCfg, rpcStat) 60 rpcinfo.AsMutableEndpointInfo(nri.From()).SetAddress(addr) 61 return nri 62 }, 63 Codec: &MockCodec{ 64 EncodeFunc: nil, 65 DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 66 msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) 67 return nil 68 }, 69 }, 70 SvcSearchMap: map[string]*serviceinfo.ServiceInfo{ 71 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, 72 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, 73 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, 74 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, 75 mocks.MockMethod: svcInfo, 76 mocks.MockExceptionMethod: svcInfo, 77 mocks.MockErrorMethod: svcInfo, 78 mocks.MockOnewayMethod: svcInfo, 79 }, 80 TargetSvcInfo: svcInfo, 81 TracerCtl: &rpcinfo.TraceController{}, 82 } 83 svrTransHdlr, _ = newSvrTransHandler(svrOpt) 84 transSvr = NewTransServerFactory().NewTransServer(svrOpt, svrTransHdlr).(*transServer) 85 86 os.Exit(m.Run()) 87 } 88 89 // TestCreateListener test trans_server CreateListener success 90 func TestCreateListener(t *testing.T) { 91 // tcp init 92 addrStr := "127.0.0.1:9091" 93 addr = utils.NewNetAddr("tcp", addrStr) 94 95 // test 96 ln, err := transSvr.CreateListener(addr) 97 test.Assert(t, err == nil, err) 98 test.Assert(t, ln.Addr().String() == addrStr) 99 ln.Close() 100 101 // uds init 102 addrStr = "server.addr" 103 addr, err = net.ResolveUnixAddr("unix", addrStr) 104 test.Assert(t, err == nil, err) 105 106 // test 107 ln, err = transSvr.CreateListener(addr) 108 test.Assert(t, err == nil, err) 109 test.Assert(t, ln.Addr().String() == addrStr) 110 ln.Close() 111 } 112 113 // TestBootStrap test trans_server BootstrapServer success 114 func TestBootStrap(t *testing.T) { 115 // tcp init 116 addrStr := "127.0.0.1:9092" 117 addr = utils.NewNetAddr("tcp", addrStr) 118 119 // test 120 ln, err := transSvr.CreateListener(addr) 121 test.Assert(t, err == nil, err) 122 test.Assert(t, ln.Addr().String() == addrStr) 123 124 var wg sync.WaitGroup 125 wg.Add(1) 126 go func() { 127 err = transSvr.BootstrapServer(ln) 128 test.Assert(t, err == nil, err) 129 wg.Done() 130 }() 131 time.Sleep(10 * time.Millisecond) 132 133 transSvr.Shutdown() 134 wg.Wait() 135 } 136 137 // TestOnConnActive test trans_server onConnActive success 138 func TestConnOnActive(t *testing.T) { 139 // 1. prepare mock data 140 conn := &MockNetpollConn{ 141 SetReadTimeoutFunc: func(timeout time.Duration) (e error) { 142 return nil 143 }, 144 Conn: mocks.Conn{ 145 RemoteAddrFunc: func() (r net.Addr) { 146 return addr 147 }, 148 }, 149 } 150 151 // 2. test 152 connCount := 100 153 for i := 0; i < connCount; i++ { 154 transSvr.onConnActive(conn) 155 } 156 ctx := context.Background() 157 158 currConnCount := transSvr.ConnCount() 159 test.Assert(t, currConnCount.Value() == connCount) 160 161 for i := 0; i < connCount; i++ { 162 transSvr.onConnInactive(ctx, conn) 163 } 164 165 currConnCount = transSvr.ConnCount() 166 test.Assert(t, currConnCount.Value() == 0) 167 } 168 169 // TestOnConnActivePanic test panic recover when panic happen in OnActive 170 func TestConnOnActiveAndOnInactivePanic(t *testing.T) { 171 ctrl := gomock.NewController(t) 172 defer func() { 173 ctrl.Finish() 174 }() 175 176 inboundHandler := mocksremote.NewMockInboundHandler(ctrl) 177 transPl := remote.NewTransPipeline(svrTransHdlr) 178 transPl.AddInboundHandler(inboundHandler) 179 transSvrWithPl := NewTransServerFactory().NewTransServer(svrOpt, transPl).(*transServer) 180 conn := &MockNetpollConn{ 181 Conn: mocks.Conn{ 182 RemoteAddrFunc: func() (r net.Addr) { 183 return addr 184 }, 185 }, 186 } 187 188 // test1: recover OnActive panic 189 inboundHandler.EXPECT().OnActive(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, conn net.Conn) (context.Context, error) { 190 panic("mock panic") 191 }) 192 transSvrWithPl.onConnActive(conn) 193 194 // test2: recover OnInactive panic 195 inboundHandler.EXPECT().OnInactive(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, conn net.Conn) (context.Context, error) { 196 panic("mock panic") 197 }) 198 transSvrWithPl.onConnInactive(context.Background(), conn) 199 } 200 201 // TestOnConnRead test trans_server onConnRead success 202 func TestConnOnRead(t *testing.T) { 203 // 1. prepare mock data 204 var isClosed bool 205 conn := &MockNetpollConn{ 206 Conn: mocks.Conn{ 207 RemoteAddrFunc: func() (r net.Addr) { 208 return addr 209 }, 210 CloseFunc: func() (e error) { 211 isClosed = true 212 return nil 213 }, 214 }, 215 } 216 mockErr := errors.New("mock error") 217 transSvr.transHdlr = &mocks.MockSvrTransHandler{ 218 OnReadFunc: func(ctx context.Context, conn net.Conn) error { 219 return mockErr 220 }, 221 Opt: transSvr.opt, 222 } 223 224 // 2. test 225 err := transSvr.onConnRead(context.Background(), conn) 226 test.Assert(t, err == nil, err) 227 test.Assert(t, isClosed) 228 }