github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/default_codec_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package codec 18 19 import ( 20 "context" 21 "encoding/binary" 22 "errors" 23 "testing" 24 25 "github.com/bytedance/mockey" 26 "github.com/golang/mock/gomock" 27 28 "github.com/cloudwego/kitex/internal/mocks" 29 mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" 30 "github.com/cloudwego/kitex/internal/test" 31 "github.com/cloudwego/kitex/pkg/remote" 32 "github.com/cloudwego/kitex/pkg/rpcinfo" 33 "github.com/cloudwego/kitex/pkg/serviceinfo" 34 "github.com/cloudwego/kitex/transport" 35 ) 36 37 func TestThriftProtocolCheck(t *testing.T) { 38 var req interface{} 39 var rbf remote.ByteBuffer 40 var ttheader bool 41 var flagBuf []byte 42 var ri rpcinfo.RPCInfo 43 var msg remote.Message 44 45 resetRIAndMSG := func() { 46 ri = rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) 47 msg = remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server) 48 } 49 50 // 1. ttheader 51 resetRIAndMSG() 52 flagBuf = make([]byte, 8*2) 53 binary.BigEndian.PutUint32(flagBuf, uint32(10)) 54 binary.BigEndian.PutUint32(flagBuf[4:8], TTHeaderMagic) 55 binary.BigEndian.PutUint32(flagBuf[8:12], ThriftV1Magic) 56 ttheader = IsTTHeader(flagBuf) 57 test.Assert(t, ttheader) 58 if ttheader { 59 flagBuf = flagBuf[8:] 60 } 61 rbf = remote.NewReaderBuffer(flagBuf) 62 err := checkPayload(flagBuf, msg, rbf, ttheader, 10) 63 test.Assert(t, err == nil, err) 64 test.Assert(t, msg.ProtocolInfo().TransProto == transport.TTHeader) 65 test.Assert(t, msg.RPCInfo().Config().TransportProtocol()&transport.TTHeader == transport.TTHeader) 66 test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift) 67 68 // 2. ttheader framed 69 resetRIAndMSG() 70 flagBuf = make([]byte, 8*2) 71 binary.BigEndian.PutUint32(flagBuf, uint32(10)) 72 binary.BigEndian.PutUint32(flagBuf[4:8], TTHeaderMagic) 73 binary.BigEndian.PutUint32(flagBuf[12:], ThriftV1Magic) 74 ttheader = IsTTHeader(flagBuf) 75 test.Assert(t, ttheader) 76 if ttheader { 77 flagBuf = flagBuf[8:] 78 } 79 rbf = remote.NewReaderBuffer(flagBuf) 80 err = checkPayload(flagBuf, msg, rbf, ttheader, 10) 81 test.Assert(t, err == nil, err) 82 test.Assert(t, msg.ProtocolInfo().TransProto == transport.TTHeaderFramed) 83 test.Assert(t, msg.RPCInfo().Config().TransportProtocol()&transport.TTHeaderFramed == transport.TTHeaderFramed) 84 test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift) 85 86 // 3. thrift framed 87 resetRIAndMSG() 88 flagBuf = make([]byte, 8*2) 89 binary.BigEndian.PutUint32(flagBuf, uint32(10)) 90 binary.BigEndian.PutUint32(flagBuf[4:8], ThriftV1Magic) 91 ttheader = IsTTHeader(flagBuf) 92 test.Assert(t, !ttheader) 93 rbf = remote.NewReaderBuffer(flagBuf) 94 err = checkPayload(flagBuf, msg, rbf, ttheader, 10) 95 test.Assert(t, err == nil, err) 96 err = checkPayload(flagBuf, msg, rbf, ttheader, 9) 97 test.Assert(t, err != nil, err) 98 test.Assert(t, msg.ProtocolInfo().TransProto == transport.Framed) 99 test.Assert(t, msg.RPCInfo().Config().TransportProtocol()&transport.Framed == transport.Framed) 100 test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift) 101 102 // 4. thrift pure payload 103 // resetRIAndMSG() // the logic below needs to check payload length set by the front case, so we don't reset ri 104 flagBuf = make([]byte, 8*2) 105 binary.BigEndian.PutUint32(flagBuf, uint32(10)) 106 binary.BigEndian.PutUint32(flagBuf[0:4], ThriftV1Magic) 107 ttheader = IsTTHeader(flagBuf) 108 test.Assert(t, !ttheader) 109 rbf = remote.NewReaderBuffer(flagBuf) 110 err = checkPayload(flagBuf, msg, rbf, ttheader, 10) 111 test.Assert(t, err == nil, err) 112 err = checkPayload(flagBuf, msg, rbf, ttheader, 9) 113 test.Assert(t, err != nil, err) 114 test.Assert(t, msg.ProtocolInfo().TransProto == transport.PurePayload) 115 test.Assert(t, msg.RPCInfo().Config().TransportProtocol()&transport.PurePayload == transport.PurePayload) 116 test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift) 117 } 118 119 func TestProtobufProtocolCheck(t *testing.T) { 120 var req interface{} 121 var rbf remote.ByteBuffer 122 var ttheader bool 123 var flagBuf []byte 124 ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) 125 msg := remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server) 126 127 // 1. ttheader framed 128 flagBuf = make([]byte, 8*2) 129 binary.BigEndian.PutUint32(flagBuf, uint32(10)) 130 binary.BigEndian.PutUint32(flagBuf[4:8], TTHeaderMagic) 131 binary.BigEndian.PutUint32(flagBuf[12:], ProtobufV1Magic) 132 ttheader = IsTTHeader(flagBuf) 133 test.Assert(t, ttheader) 134 if ttheader { 135 flagBuf = flagBuf[8:] 136 } 137 rbf = remote.NewReaderBuffer(flagBuf) 138 err := checkPayload(flagBuf, msg, rbf, ttheader, 10) 139 test.Assert(t, err == nil, err) 140 test.Assert(t, msg.ProtocolInfo().TransProto == transport.TTHeaderFramed) 141 test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Protobuf) 142 143 // 2. protobuf framed 144 flagBuf = make([]byte, 8*2) 145 binary.BigEndian.PutUint32(flagBuf, uint32(10)) 146 binary.BigEndian.PutUint32(flagBuf[4:8], ProtobufV1Magic) 147 ttheader = IsTTHeader(flagBuf) 148 test.Assert(t, !ttheader) 149 rbf = remote.NewReaderBuffer(flagBuf) 150 err = checkPayload(flagBuf, msg, rbf, ttheader, 10) 151 test.Assert(t, err == nil, err) 152 err = checkPayload(flagBuf, msg, rbf, ttheader, 9) 153 test.Assert(t, err != nil, err) 154 test.Assert(t, msg.ProtocolInfo().TransProto == transport.Framed) 155 test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Protobuf) 156 } 157 158 func TestDefaultCodec_Encode_Decode(t *testing.T) { 159 remote.PutPayloadCode(serviceinfo.Thrift, mpc) 160 161 dc := NewDefaultCodec() 162 ctx := context.Background() 163 intKVInfo := prepareIntKVInfo() 164 strKVInfo := prepareStrKVInfo() 165 sendMsg := initClientSendMsg(transport.TTHeader) 166 sendMsg.TransInfo().PutTransIntInfo(intKVInfo) 167 sendMsg.TransInfo().PutTransStrInfo(strKVInfo) 168 169 // test encode err 170 out := remote.NewReaderBuffer([]byte{}) 171 err := dc.Encode(ctx, sendMsg, out) 172 test.Assert(t, err != nil) 173 174 // encode 175 out = remote.NewWriterBuffer(256) 176 err = dc.Encode(ctx, sendMsg, out) 177 test.Assert(t, err == nil, err) 178 179 // decode 180 recvMsg := initServerRecvMsg() 181 buf, err := out.Bytes() 182 test.Assert(t, err == nil, err) 183 in := remote.NewReaderBuffer(buf) 184 err = dc.Decode(ctx, recvMsg, in) 185 test.Assert(t, err == nil, err) 186 187 intKVInfoRecv := recvMsg.TransInfo().TransIntInfo() 188 strKVInfoRecv := recvMsg.TransInfo().TransStrInfo() 189 test.DeepEqual(t, intKVInfoRecv, intKVInfo) 190 test.DeepEqual(t, strKVInfoRecv, strKVInfo) 191 test.Assert(t, sendMsg.RPCInfo().Invocation().SeqID() == recvMsg.RPCInfo().Invocation().SeqID()) 192 } 193 194 func TestDefaultSizedCodec_Encode_Decode(t *testing.T) { 195 remote.PutPayloadCode(serviceinfo.Thrift, mpc) 196 197 smallDc := NewDefaultCodecWithSizeLimit(1) 198 largeDc := NewDefaultCodecWithSizeLimit(1024) 199 ctx := context.Background() 200 intKVInfo := prepareIntKVInfo() 201 strKVInfo := prepareStrKVInfo() 202 sendMsg := initClientSendMsg(transport.TTHeader) 203 sendMsg.TransInfo().PutTransIntInfo(intKVInfo) 204 sendMsg.TransInfo().PutTransStrInfo(strKVInfo) 205 206 // encode 207 smallOut := remote.NewWriterBuffer(256) 208 largeOut := remote.NewWriterBuffer(256) 209 err := smallDc.Encode(ctx, sendMsg, smallOut) 210 test.Assert(t, err != nil, err) 211 err = largeDc.Encode(ctx, sendMsg, largeOut) 212 test.Assert(t, err == nil, err) 213 214 // decode 215 recvMsg := initServerRecvMsg() 216 smallBuf, _ := smallOut.Bytes() 217 largeBuf, _ := largeOut.Bytes() 218 err = smallDc.Decode(ctx, recvMsg, remote.NewReaderBuffer(smallBuf)) 219 test.Assert(t, err != nil, err) 220 err = largeDc.Decode(ctx, recvMsg, remote.NewReaderBuffer(largeBuf)) 221 test.Assert(t, err == nil, err) 222 } 223 224 func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) { 225 var req interface{} 226 remote.PutPayloadCode(serviceinfo.Thrift, mpc) 227 remote.PutPayloadCode(serviceinfo.Protobuf, mpc) 228 ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats()) 229 codec := NewDefaultCodec() 230 231 // case 1: the payloadCodec of svcInfo is Protobuf, CodecType of message is Thrift 232 svcInfo := &serviceinfo.ServiceInfo{ 233 PayloadCodec: serviceinfo.Protobuf, 234 } 235 msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Server) 236 msg.SetProtocolInfo(remote.ProtocolInfo{TransProto: transport.TTHeader, CodecType: serviceinfo.Thrift}) 237 err := codec.Encode(context.Background(), msg, remote.NewWriterBuffer(256)) 238 test.Assert(t, err == nil, err) 239 240 // case 2: the payloadCodec of svcInfo is Thrift, CodecType of message is Protobuf 241 svcInfo = &serviceinfo.ServiceInfo{ 242 PayloadCodec: serviceinfo.Thrift, 243 } 244 msg = remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Server) 245 msg.SetProtocolInfo(remote.ProtocolInfo{TransProto: transport.TTHeader, CodecType: serviceinfo.Protobuf}) 246 err = codec.Encode(context.Background(), msg, remote.NewWriterBuffer(256)) 247 test.Assert(t, err != nil) 248 msg.SetProtocolInfo(remote.ProtocolInfo{TransProto: transport.Framed, CodecType: serviceinfo.Protobuf}) 249 err = codec.Encode(context.Background(), msg, remote.NewWriterBuffer(256)) 250 test.Assert(t, err == nil) 251 } 252 253 var mpc remote.PayloadCodec = mockPayloadCodec{} 254 255 type mockPayloadCodec struct{} 256 257 func (m mockPayloadCodec) Marshal(ctx context.Context, message remote.Message, out remote.ByteBuffer) error { 258 WriteUint32(ThriftV1Magic+uint32(message.MessageType()), out) 259 WriteString(message.RPCInfo().Invocation().MethodName(), out) 260 WriteUint32(uint32(message.RPCInfo().Invocation().SeqID()), out) 261 return nil 262 } 263 264 func (m mockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { 265 magicAndMsgType, err := ReadUint32(in) 266 if err != nil { 267 return err 268 } 269 if magicAndMsgType&MagicMask != ThriftV1Magic { 270 return errors.New("bad version") 271 } 272 msgType := magicAndMsgType & FrontMask 273 if err := UpdateMsgType(msgType, message); err != nil { 274 return err 275 } 276 277 methodName, _, err := ReadString(in) 278 if err != nil { 279 return err 280 } 281 if err = SetOrCheckMethodName(methodName, message); err != nil && msgType != uint32(remote.Exception) { 282 return err 283 } 284 seqID, err := ReadUint32(in) 285 if err != nil { 286 return err 287 } 288 if err = SetOrCheckSeqID(int32(seqID), message); err != nil && msgType != uint32(remote.Exception) { 289 return err 290 } 291 return nil 292 } 293 294 func (m mockPayloadCodec) Name() string { 295 return "mock" 296 } 297 298 func TestCornerCase(t *testing.T) { 299 ctrl := gomock.NewController(t) 300 defer ctrl.Finish() 301 302 sendMsg := initClientSendMsg(transport.TTHeader) 303 sendMsg.SetProtocolInfo(remote.NewProtocolInfo(transport.Framed, serviceinfo.Thrift)) 304 305 buffer := mocksremote.NewMockByteBuffer(ctrl) 306 buffer.EXPECT().MallocLen().Return(1024).AnyTimes() 307 buffer.EXPECT().Malloc(gomock.Any()).Return(nil, errors.New("error malloc")).AnyTimes() 308 err := (&defaultCodec{}).EncodePayload(context.Background(), sendMsg, buffer) 309 test.Assert(t, err.Error() == "error malloc") 310 311 mockey.PatchConvey("", t, func() { 312 mockey.Mock(remote.GetPayloadCodec).Return(nil, errors.New("err get payload codec")).Build() 313 buffer = mocksremote.NewMockByteBuffer(ctrl) 314 buffer.EXPECT().MallocLen().Return(1024).AnyTimes() 315 buffer.EXPECT().Malloc(gomock.Any()).Return(nil, nil).AnyTimes() 316 err := (&defaultCodec{}).EncodePayload(context.Background(), sendMsg, buffer) 317 test.Assert(t, err.Error() == "err get payload codec") 318 }) 319 }