github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/detection/server_handler_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package detection 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "net" 24 "testing" 25 26 "github.com/golang/mock/gomock" 27 28 "github.com/cloudwego/kitex/internal/mocks" 29 mocksklog "github.com/cloudwego/kitex/internal/mocks/klog" 30 npmocks "github.com/cloudwego/kitex/internal/mocks/netpoll" 31 remote_mocks "github.com/cloudwego/kitex/internal/mocks/remote" 32 "github.com/cloudwego/kitex/internal/test" 33 "github.com/cloudwego/kitex/pkg/klog" 34 "github.com/cloudwego/kitex/pkg/remote" 35 "github.com/cloudwego/kitex/pkg/remote/codec" 36 "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" 37 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" 38 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" 39 "github.com/cloudwego/kitex/pkg/serviceinfo" 40 "github.com/cloudwego/kitex/pkg/utils" 41 ) 42 43 var ( 44 prefaceReadAtMost = func() int { 45 // min(len(ClientPreface), len(flagBuf)) 46 // len(flagBuf) = 2 * codec.Size32 47 if 2*codec.Size32 < grpc.ClientPrefaceLen { 48 return 2 * codec.Size32 49 } 50 return grpc.ClientPrefaceLen 51 }() 52 svcInfo = mocks.ServiceInfo() 53 svcSearchMap = map[string]*serviceinfo.ServiceInfo{ 54 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, 55 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, 56 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, 57 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, 58 mocks.MockMethod: svcInfo, 59 mocks.MockExceptionMethod: svcInfo, 60 mocks.MockErrorMethod: svcInfo, 61 mocks.MockOnewayMethod: svcInfo, 62 } 63 ) 64 65 func TestServerHandlerCall(t *testing.T) { 66 transHdler, _ := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ 67 SvcSearchMap: svcSearchMap, 68 TargetSvcInfo: svcInfo, 69 }) 70 71 ctrl := gomock.NewController(t) 72 defer ctrl.Finish() 73 74 npConn := npmocks.NewMockConnection(ctrl) 75 npReader := npmocks.NewMockReader(ctrl) 76 hdl := remote_mocks.NewMockServerTransHandler(ctrl) 77 78 errOnActive := errors.New("mock on active error") 79 errOnRead := errors.New("mock on read error") 80 81 triggerReadErr := false 82 triggerActiveErr := false 83 84 hdl.EXPECT().OnActive(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, conn net.Conn) (context.Context, error) { 85 if triggerActiveErr { 86 return ctx, errOnActive 87 } 88 return ctx, nil 89 }).AnyTimes() 90 hdl.EXPECT().OnRead(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, conn net.Conn) error { 91 if triggerReadErr { 92 return errOnRead 93 } 94 return nil 95 }).AnyTimes() 96 hdl.EXPECT().OnInactive(gomock.Any(), gomock.Any()).AnyTimes() 97 hdl.EXPECT().OnError(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, err error, conn net.Conn) { 98 if conn != nil { 99 klog.CtxErrorf(ctx, "KITEX: processing error, remoteAddr=%v, error=%s", conn.RemoteAddr(), err.Error()) 100 } else { 101 klog.CtxErrorf(ctx, "KITEX: processing error, error=%s", err.Error()) 102 } 103 }).AnyTimes() 104 105 npReader.EXPECT().Peek(prefaceReadAtMost).Return([]byte("connection prefix bytes"), nil).AnyTimes() 106 npConn.EXPECT().Reader().Return(npReader).AnyTimes() 107 npConn.EXPECT().RemoteAddr().Return(nil).AnyTimes() 108 109 transHdler.(*svrTransHandler).defaultHandler = hdl 110 111 // case1 successful call: onActive() and onRead() all success 112 triggerActiveErr = false 113 triggerReadErr = false 114 err := mockCall(transHdler, npConn) 115 test.Assert(t, err == nil, err) 116 117 // case2 onActive failed: onActive() err and close conn 118 triggerActiveErr = true 119 triggerReadErr = false 120 err = mockCall(transHdler, npConn) 121 test.Assert(t, err == errOnActive, err) 122 123 // case3 onRead failed: onRead() err and close conn 124 triggerActiveErr = false 125 triggerReadErr = true 126 err = mockCall(transHdler, npConn) 127 test.Assert(t, err == errOnRead, err) 128 } 129 130 func TestOnError(t *testing.T) { 131 ctrl := gomock.NewController(t) 132 defer func() { 133 klog.SetLogger(klog.DefaultLogger()) 134 ctrl.Finish() 135 }() 136 transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ 137 SvcSearchMap: svcSearchMap, 138 TargetSvcInfo: svcInfo, 139 }) 140 test.Assert(t, err == nil) 141 142 mocklogger := mocksklog.NewMockFullLogger(ctrl) 143 klog.SetLogger(mocklogger) 144 145 var errMsg string 146 mocklogger.EXPECT().CtxErrorf(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(ctx context.Context, format string, v ...interface{}) { 147 errMsg = fmt.Sprintf(format, v...) 148 }) 149 transHdler.OnError(context.Background(), errors.New("mock error"), nil) 150 test.Assert(t, errMsg == "KITEX: processing error, error=mock error", errMsg) 151 152 conn := &mocks.Conn{ 153 RemoteAddrFunc: func() (r net.Addr) { 154 return utils.NewNetAddr("mock", "mock") 155 }, 156 } 157 mocklogger.EXPECT().CtxErrorf(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(ctx context.Context, format string, v ...interface{}) { 158 errMsg = fmt.Sprintf(format, v...) 159 }) 160 transHdler.OnError(context.Background(), errors.New("mock error"), conn) 161 test.Assert(t, errMsg == "KITEX: processing error, remoteAddr=mock, error=mock error", errMsg) 162 } 163 164 // TestOnInactive covers onInactive() codes to check panic 165 func TestOnInactive(t *testing.T) { 166 transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{ 167 SvcSearchMap: svcSearchMap, 168 TargetSvcInfo: svcInfo, 169 }) 170 test.Assert(t, err == nil) 171 172 conn := &mocks.Conn{ 173 RemoteAddrFunc: func() (r net.Addr) { 174 return utils.NewNetAddr("mock", "mock") 175 }, 176 } 177 178 // case1 test noopHandler onInactive() 179 transHdler.OnInactive(context.Background(), conn) 180 181 // mock a ctx and set handlerKey 182 subHandler := &mocks.MockSvrTransHandler{} 183 subHandlerCtx := context.WithValue( 184 context.Background(), 185 handlerKey{}, 186 &handlerWrapper{ 187 handler: subHandler, 188 }, 189 ) 190 191 ctx := context.WithValue( 192 context.Background(), 193 handlerKey{}, 194 &handlerWrapper{ 195 ctx: subHandlerCtx, 196 }, 197 ) 198 // case2 test subHandler onInactive() 199 transHdler.OnInactive(ctx, conn) 200 } 201 202 // mockCall mocks how detection transHdlr processing with incoming requests 203 func mockCall(transHdlr remote.ServerTransHandler, conn net.Conn) (err error) { 204 ctx := context.Background() 205 // do onConnActive 206 ctxWithHandler, err := transHdlr.OnActive(ctx, conn) 207 // onActive failed, such as connections limitation 208 if err != nil { 209 transHdlr.OnError(ctx, err, conn) 210 transHdlr.OnInactive(ctx, conn) 211 return 212 } 213 // do onConnRead 214 err = transHdlr.OnRead(ctxWithHandler, conn) 215 if err != nil { 216 transHdlr.OnError(ctxWithHandler, err, conn) 217 transHdlr.OnInactive(ctxWithHandler, conn) 218 return 219 } 220 // do onConnInactive 221 transHdlr.OnInactive(ctxWithHandler, conn) 222 return 223 }