github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/default_server_handler_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package trans 18 19 import ( 20 "context" 21 "errors" 22 "net" 23 "testing" 24 25 "github.com/golang/mock/gomock" 26 27 "github.com/cloudwego/kitex/internal/mocks" 28 "github.com/cloudwego/kitex/internal/mocks/stats" 29 "github.com/cloudwego/kitex/internal/test" 30 "github.com/cloudwego/kitex/pkg/kerrors" 31 "github.com/cloudwego/kitex/pkg/remote" 32 "github.com/cloudwego/kitex/pkg/rpcinfo" 33 "github.com/cloudwego/kitex/pkg/serviceinfo" 34 ) 35 36 var ( 37 svcInfo = mocks.ServiceInfo() 38 svcSearchMap = map[string]*serviceinfo.ServiceInfo{ 39 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod): svcInfo, 40 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo, 41 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod): svcInfo, 42 remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod): svcInfo, 43 mocks.MockMethod: svcInfo, 44 mocks.MockExceptionMethod: svcInfo, 45 mocks.MockErrorMethod: svcInfo, 46 mocks.MockOnewayMethod: svcInfo, 47 } 48 ) 49 50 func TestDefaultSvrTransHandler(t *testing.T) { 51 buf := remote.NewReaderWriterBuffer(1024) 52 ext := &MockExtension{ 53 NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 54 return buf 55 }, 56 NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 57 return buf 58 }, 59 } 60 61 tagEncode, tagDecode := 0, 0 62 opt := &remote.ServerOption{ 63 Codec: &MockCodec{ 64 EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 65 tagEncode++ 66 test.Assert(t, out == buf) 67 return nil 68 }, 69 DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 70 tagDecode++ 71 test.Assert(t, in == buf) 72 return nil 73 }, 74 }, 75 SvcSearchMap: svcSearchMap, 76 TargetSvcInfo: svcInfo, 77 } 78 79 handler, err := NewDefaultSvrTransHandler(opt, ext) 80 test.Assert(t, err == nil) 81 82 ctx := context.Background() 83 conn := &mocks.Conn{} 84 msg := &MockMessage{ 85 RPCInfoFunc: func() rpcinfo.RPCInfo { 86 return newMockRPCInfo() 87 }, 88 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 89 return &serviceinfo.ServiceInfo{ 90 Methods: map[string]serviceinfo.MethodInfo{ 91 "method": serviceinfo.NewMethodInfo(nil, nil, nil, false), 92 }, 93 } 94 }, 95 } 96 ctx, err = handler.Write(ctx, conn, msg) 97 test.Assert(t, ctx != nil, ctx) 98 test.Assert(t, err == nil, err) 99 test.Assert(t, tagEncode == 1, tagEncode) 100 test.Assert(t, tagDecode == 0, tagDecode) 101 102 ctx, err = handler.Read(ctx, conn, msg) 103 test.Assert(t, ctx != nil, ctx) 104 test.Assert(t, err == nil, err) 105 test.Assert(t, tagEncode == 1, tagEncode) 106 test.Assert(t, tagDecode == 1, tagDecode) 107 } 108 109 func TestSvrTransHandlerBizError(t *testing.T) { 110 ctrl := gomock.NewController(t) 111 defer ctrl.Finish() 112 113 mockTracer := stats.NewMockTracer(ctrl) 114 mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes() 115 mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) { 116 err := rpcinfo.GetRPCInfo(ctx).Stats().Error() 117 test.Assert(t, err != nil) 118 }).AnyTimes() 119 120 buf := remote.NewReaderWriterBuffer(1024) 121 ext := &MockExtension{ 122 NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 123 return buf 124 }, 125 NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 126 return buf 127 }, 128 } 129 130 tracerCtl := &rpcinfo.TraceController{} 131 tracerCtl.Append(mockTracer) 132 opt := &remote.ServerOption{ 133 Codec: &MockCodec{ 134 EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 135 return nil 136 }, 137 DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 138 msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) 139 return nil 140 }, 141 }, 142 SvcSearchMap: svcSearchMap, 143 TargetSvcInfo: svcInfo, 144 TracerCtl: tracerCtl, 145 InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { 146 rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr) 147 return ri 148 }, 149 } 150 ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}), 151 rpcinfo.NewInvocation("", mocks.MockMethod), nil, rpcinfo.NewRPCStats()) 152 ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) 153 154 svrHandler, err := NewDefaultSvrTransHandler(opt, ext) 155 pl := remote.NewTransPipeline(svrHandler) 156 svrHandler.SetPipeline(pl) 157 if setter, ok := svrHandler.(remote.InvokeHandleFuncSetter); ok { 158 setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { 159 return kerrors.ErrBiz.WithCause(errors.New("mock")) 160 }) 161 } 162 test.Assert(t, err == nil) 163 err = svrHandler.OnRead(ctx, &mocks.Conn{}) 164 test.Assert(t, err == nil) 165 } 166 167 func TestSvrTransHandlerReadErr(t *testing.T) { 168 ctrl := gomock.NewController(t) 169 defer ctrl.Finish() 170 171 mockTracer := stats.NewMockTracer(ctrl) 172 mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes() 173 mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) { 174 err := rpcinfo.GetRPCInfo(ctx).Stats().Error() 175 test.Assert(t, err != nil) 176 }).AnyTimes() 177 178 buf := remote.NewReaderWriterBuffer(1024) 179 ext := &MockExtension{ 180 NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 181 return buf 182 }, 183 NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 184 return buf 185 }, 186 } 187 188 mockErr := errors.New("mock") 189 tracerCtl := &rpcinfo.TraceController{} 190 tracerCtl.Append(mockTracer) 191 opt := &remote.ServerOption{ 192 Codec: &MockCodec{ 193 EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 194 return nil 195 }, 196 DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 197 msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) 198 return mockErr 199 }, 200 }, 201 SvcSearchMap: svcSearchMap, 202 TargetSvcInfo: svcInfo, 203 TracerCtl: tracerCtl, 204 InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { 205 rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr) 206 return ri 207 }, 208 } 209 ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}), 210 rpcinfo.NewInvocation("", mocks.MockMethod), nil, rpcinfo.NewRPCStats()) 211 ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) 212 213 svrHandler, err := NewDefaultSvrTransHandler(opt, ext) 214 test.Assert(t, err == nil) 215 pl := remote.NewTransPipeline(svrHandler) 216 svrHandler.SetPipeline(pl) 217 err = svrHandler.OnRead(ctx, &mocks.Conn{}) 218 test.Assert(t, err != nil) 219 test.Assert(t, errors.Is(err, mockErr)) 220 } 221 222 func TestSvrTransHandlerOnReadHeartbeat(t *testing.T) { 223 ctrl := gomock.NewController(t) 224 defer ctrl.Finish() 225 226 mockTracer := stats.NewMockTracer(ctrl) 227 mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes() 228 mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) { 229 err := rpcinfo.GetRPCInfo(ctx).Stats().Error() 230 test.Assert(t, err == nil) 231 }).AnyTimes() 232 233 buf := remote.NewReaderWriterBuffer(1024) 234 ext := &MockExtension{ 235 NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 236 return buf 237 }, 238 NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer { 239 return buf 240 }, 241 } 242 243 tracerCtl := &rpcinfo.TraceController{} 244 tracerCtl.Append(mockTracer) 245 opt := &remote.ServerOption{ 246 Codec: &MockCodec{ 247 EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { 248 if msg.MessageType() != remote.Heartbeat { 249 return errors.New("response is not of MessageType Heartbeat") 250 } 251 return nil 252 }, 253 DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error { 254 msg.SetMessageType(remote.Heartbeat) 255 msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod) 256 return nil 257 }, 258 }, 259 SvcSearchMap: svcSearchMap, 260 TargetSvcInfo: svcInfo, 261 TracerCtl: tracerCtl, 262 InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo { 263 rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr) 264 return ri 265 }, 266 } 267 ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}), 268 rpcinfo.NewInvocation("", mocks.MockMethod), nil, rpcinfo.NewRPCStats()) 269 ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) 270 271 svrHandler, err := NewDefaultSvrTransHandler(opt, ext) 272 test.Assert(t, err == nil) 273 pl := remote.NewTransPipeline(svrHandler) 274 svrHandler.SetPipeline(pl) 275 err = svrHandler.OnRead(ctx, &mocks.Conn{}) 276 test.Assert(t, err == nil) 277 }