github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/default_codec_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package codec
    18  
    19  import (
    20  	"context"
    21  	"encoding/binary"
    22  	"errors"
    23  	"testing"
    24  
    25  	"github.com/bytedance/mockey"
    26  	"github.com/golang/mock/gomock"
    27  
    28  	"github.com/cloudwego/kitex/internal/mocks"
    29  	mocksremote "github.com/cloudwego/kitex/internal/mocks/remote"
    30  	"github.com/cloudwego/kitex/internal/test"
    31  	"github.com/cloudwego/kitex/pkg/remote"
    32  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    33  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    34  	"github.com/cloudwego/kitex/transport"
    35  )
    36  
    37  func TestThriftProtocolCheck(t *testing.T) {
    38  	var req interface{}
    39  	var rbf remote.ByteBuffer
    40  	var ttheader bool
    41  	var flagBuf []byte
    42  	var ri rpcinfo.RPCInfo
    43  	var msg remote.Message
    44  
    45  	resetRIAndMSG := func() {
    46  		ri = rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
    47  		msg = remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server)
    48  	}
    49  
    50  	// 1. ttheader
    51  	resetRIAndMSG()
    52  	flagBuf = make([]byte, 8*2)
    53  	binary.BigEndian.PutUint32(flagBuf, uint32(10))
    54  	binary.BigEndian.PutUint32(flagBuf[4:8], TTHeaderMagic)
    55  	binary.BigEndian.PutUint32(flagBuf[8:12], ThriftV1Magic)
    56  	ttheader = IsTTHeader(flagBuf)
    57  	test.Assert(t, ttheader)
    58  	if ttheader {
    59  		flagBuf = flagBuf[8:]
    60  	}
    61  	rbf = remote.NewReaderBuffer(flagBuf)
    62  	err := checkPayload(flagBuf, msg, rbf, ttheader, 10)
    63  	test.Assert(t, err == nil, err)
    64  	test.Assert(t, msg.ProtocolInfo().TransProto == transport.TTHeader)
    65  	test.Assert(t, msg.RPCInfo().Config().TransportProtocol()&transport.TTHeader == transport.TTHeader)
    66  	test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift)
    67  
    68  	// 2. ttheader framed
    69  	resetRIAndMSG()
    70  	flagBuf = make([]byte, 8*2)
    71  	binary.BigEndian.PutUint32(flagBuf, uint32(10))
    72  	binary.BigEndian.PutUint32(flagBuf[4:8], TTHeaderMagic)
    73  	binary.BigEndian.PutUint32(flagBuf[12:], ThriftV1Magic)
    74  	ttheader = IsTTHeader(flagBuf)
    75  	test.Assert(t, ttheader)
    76  	if ttheader {
    77  		flagBuf = flagBuf[8:]
    78  	}
    79  	rbf = remote.NewReaderBuffer(flagBuf)
    80  	err = checkPayload(flagBuf, msg, rbf, ttheader, 10)
    81  	test.Assert(t, err == nil, err)
    82  	test.Assert(t, msg.ProtocolInfo().TransProto == transport.TTHeaderFramed)
    83  	test.Assert(t, msg.RPCInfo().Config().TransportProtocol()&transport.TTHeaderFramed == transport.TTHeaderFramed)
    84  	test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift)
    85  
    86  	// 3. thrift framed
    87  	resetRIAndMSG()
    88  	flagBuf = make([]byte, 8*2)
    89  	binary.BigEndian.PutUint32(flagBuf, uint32(10))
    90  	binary.BigEndian.PutUint32(flagBuf[4:8], ThriftV1Magic)
    91  	ttheader = IsTTHeader(flagBuf)
    92  	test.Assert(t, !ttheader)
    93  	rbf = remote.NewReaderBuffer(flagBuf)
    94  	err = checkPayload(flagBuf, msg, rbf, ttheader, 10)
    95  	test.Assert(t, err == nil, err)
    96  	err = checkPayload(flagBuf, msg, rbf, ttheader, 9)
    97  	test.Assert(t, err != nil, err)
    98  	test.Assert(t, msg.ProtocolInfo().TransProto == transport.Framed)
    99  	test.Assert(t, msg.RPCInfo().Config().TransportProtocol()&transport.Framed == transport.Framed)
   100  	test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift)
   101  
   102  	// 4. thrift pure payload
   103  	// resetRIAndMSG() // the logic below needs to check payload length set by the front case, so we don't reset ri
   104  	flagBuf = make([]byte, 8*2)
   105  	binary.BigEndian.PutUint32(flagBuf, uint32(10))
   106  	binary.BigEndian.PutUint32(flagBuf[0:4], ThriftV1Magic)
   107  	ttheader = IsTTHeader(flagBuf)
   108  	test.Assert(t, !ttheader)
   109  	rbf = remote.NewReaderBuffer(flagBuf)
   110  	err = checkPayload(flagBuf, msg, rbf, ttheader, 10)
   111  	test.Assert(t, err == nil, err)
   112  	err = checkPayload(flagBuf, msg, rbf, ttheader, 9)
   113  	test.Assert(t, err != nil, err)
   114  	test.Assert(t, msg.ProtocolInfo().TransProto == transport.PurePayload)
   115  	test.Assert(t, msg.RPCInfo().Config().TransportProtocol()&transport.PurePayload == transport.PurePayload)
   116  	test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Thrift)
   117  }
   118  
   119  func TestProtobufProtocolCheck(t *testing.T) {
   120  	var req interface{}
   121  	var rbf remote.ByteBuffer
   122  	var ttheader bool
   123  	var flagBuf []byte
   124  	ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
   125  	msg := remote.NewMessage(req, mocks.ServiceInfo(), ri, remote.Call, remote.Server)
   126  
   127  	// 1. ttheader framed
   128  	flagBuf = make([]byte, 8*2)
   129  	binary.BigEndian.PutUint32(flagBuf, uint32(10))
   130  	binary.BigEndian.PutUint32(flagBuf[4:8], TTHeaderMagic)
   131  	binary.BigEndian.PutUint32(flagBuf[12:], ProtobufV1Magic)
   132  	ttheader = IsTTHeader(flagBuf)
   133  	test.Assert(t, ttheader)
   134  	if ttheader {
   135  		flagBuf = flagBuf[8:]
   136  	}
   137  	rbf = remote.NewReaderBuffer(flagBuf)
   138  	err := checkPayload(flagBuf, msg, rbf, ttheader, 10)
   139  	test.Assert(t, err == nil, err)
   140  	test.Assert(t, msg.ProtocolInfo().TransProto == transport.TTHeaderFramed)
   141  	test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Protobuf)
   142  
   143  	// 2. protobuf framed
   144  	flagBuf = make([]byte, 8*2)
   145  	binary.BigEndian.PutUint32(flagBuf, uint32(10))
   146  	binary.BigEndian.PutUint32(flagBuf[4:8], ProtobufV1Magic)
   147  	ttheader = IsTTHeader(flagBuf)
   148  	test.Assert(t, !ttheader)
   149  	rbf = remote.NewReaderBuffer(flagBuf)
   150  	err = checkPayload(flagBuf, msg, rbf, ttheader, 10)
   151  	test.Assert(t, err == nil, err)
   152  	err = checkPayload(flagBuf, msg, rbf, ttheader, 9)
   153  	test.Assert(t, err != nil, err)
   154  	test.Assert(t, msg.ProtocolInfo().TransProto == transport.Framed)
   155  	test.Assert(t, msg.ProtocolInfo().CodecType == serviceinfo.Protobuf)
   156  }
   157  
   158  func TestDefaultCodec_Encode_Decode(t *testing.T) {
   159  	remote.PutPayloadCode(serviceinfo.Thrift, mpc)
   160  
   161  	dc := NewDefaultCodec()
   162  	ctx := context.Background()
   163  	intKVInfo := prepareIntKVInfo()
   164  	strKVInfo := prepareStrKVInfo()
   165  	sendMsg := initClientSendMsg(transport.TTHeader)
   166  	sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
   167  	sendMsg.TransInfo().PutTransStrInfo(strKVInfo)
   168  
   169  	// test encode err
   170  	out := remote.NewReaderBuffer([]byte{})
   171  	err := dc.Encode(ctx, sendMsg, out)
   172  	test.Assert(t, err != nil)
   173  
   174  	// encode
   175  	out = remote.NewWriterBuffer(256)
   176  	err = dc.Encode(ctx, sendMsg, out)
   177  	test.Assert(t, err == nil, err)
   178  
   179  	// decode
   180  	recvMsg := initServerRecvMsg()
   181  	buf, err := out.Bytes()
   182  	test.Assert(t, err == nil, err)
   183  	in := remote.NewReaderBuffer(buf)
   184  	err = dc.Decode(ctx, recvMsg, in)
   185  	test.Assert(t, err == nil, err)
   186  
   187  	intKVInfoRecv := recvMsg.TransInfo().TransIntInfo()
   188  	strKVInfoRecv := recvMsg.TransInfo().TransStrInfo()
   189  	test.DeepEqual(t, intKVInfoRecv, intKVInfo)
   190  	test.DeepEqual(t, strKVInfoRecv, strKVInfo)
   191  	test.Assert(t, sendMsg.RPCInfo().Invocation().SeqID() == recvMsg.RPCInfo().Invocation().SeqID())
   192  }
   193  
   194  func TestDefaultSizedCodec_Encode_Decode(t *testing.T) {
   195  	remote.PutPayloadCode(serviceinfo.Thrift, mpc)
   196  
   197  	smallDc := NewDefaultCodecWithSizeLimit(1)
   198  	largeDc := NewDefaultCodecWithSizeLimit(1024)
   199  	ctx := context.Background()
   200  	intKVInfo := prepareIntKVInfo()
   201  	strKVInfo := prepareStrKVInfo()
   202  	sendMsg := initClientSendMsg(transport.TTHeader)
   203  	sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
   204  	sendMsg.TransInfo().PutTransStrInfo(strKVInfo)
   205  
   206  	// encode
   207  	smallOut := remote.NewWriterBuffer(256)
   208  	largeOut := remote.NewWriterBuffer(256)
   209  	err := smallDc.Encode(ctx, sendMsg, smallOut)
   210  	test.Assert(t, err != nil, err)
   211  	err = largeDc.Encode(ctx, sendMsg, largeOut)
   212  	test.Assert(t, err == nil, err)
   213  
   214  	// decode
   215  	recvMsg := initServerRecvMsg()
   216  	smallBuf, _ := smallOut.Bytes()
   217  	largeBuf, _ := largeOut.Bytes()
   218  	err = smallDc.Decode(ctx, recvMsg, remote.NewReaderBuffer(smallBuf))
   219  	test.Assert(t, err != nil, err)
   220  	err = largeDc.Decode(ctx, recvMsg, remote.NewReaderBuffer(largeBuf))
   221  	test.Assert(t, err == nil, err)
   222  }
   223  
   224  func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) {
   225  	var req interface{}
   226  	remote.PutPayloadCode(serviceinfo.Thrift, mpc)
   227  	remote.PutPayloadCode(serviceinfo.Protobuf, mpc)
   228  	ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
   229  	codec := NewDefaultCodec()
   230  
   231  	// case 1: the payloadCodec of svcInfo is Protobuf, CodecType of message is Thrift
   232  	svcInfo := &serviceinfo.ServiceInfo{
   233  		PayloadCodec: serviceinfo.Protobuf,
   234  	}
   235  	msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Server)
   236  	msg.SetProtocolInfo(remote.ProtocolInfo{TransProto: transport.TTHeader, CodecType: serviceinfo.Thrift})
   237  	err := codec.Encode(context.Background(), msg, remote.NewWriterBuffer(256))
   238  	test.Assert(t, err == nil, err)
   239  
   240  	// case 2: the payloadCodec of svcInfo is Thrift, CodecType of message is Protobuf
   241  	svcInfo = &serviceinfo.ServiceInfo{
   242  		PayloadCodec: serviceinfo.Thrift,
   243  	}
   244  	msg = remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Server)
   245  	msg.SetProtocolInfo(remote.ProtocolInfo{TransProto: transport.TTHeader, CodecType: serviceinfo.Protobuf})
   246  	err = codec.Encode(context.Background(), msg, remote.NewWriterBuffer(256))
   247  	test.Assert(t, err != nil)
   248  	msg.SetProtocolInfo(remote.ProtocolInfo{TransProto: transport.Framed, CodecType: serviceinfo.Protobuf})
   249  	err = codec.Encode(context.Background(), msg, remote.NewWriterBuffer(256))
   250  	test.Assert(t, err == nil)
   251  }
   252  
   253  var mpc remote.PayloadCodec = mockPayloadCodec{}
   254  
   255  type mockPayloadCodec struct{}
   256  
   257  func (m mockPayloadCodec) Marshal(ctx context.Context, message remote.Message, out remote.ByteBuffer) error {
   258  	WriteUint32(ThriftV1Magic+uint32(message.MessageType()), out)
   259  	WriteString(message.RPCInfo().Invocation().MethodName(), out)
   260  	WriteUint32(uint32(message.RPCInfo().Invocation().SeqID()), out)
   261  	return nil
   262  }
   263  
   264  func (m mockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
   265  	magicAndMsgType, err := ReadUint32(in)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	if magicAndMsgType&MagicMask != ThriftV1Magic {
   270  		return errors.New("bad version")
   271  	}
   272  	msgType := magicAndMsgType & FrontMask
   273  	if err := UpdateMsgType(msgType, message); err != nil {
   274  		return err
   275  	}
   276  
   277  	methodName, _, err := ReadString(in)
   278  	if err != nil {
   279  		return err
   280  	}
   281  	if err = SetOrCheckMethodName(methodName, message); err != nil && msgType != uint32(remote.Exception) {
   282  		return err
   283  	}
   284  	seqID, err := ReadUint32(in)
   285  	if err != nil {
   286  		return err
   287  	}
   288  	if err = SetOrCheckSeqID(int32(seqID), message); err != nil && msgType != uint32(remote.Exception) {
   289  		return err
   290  	}
   291  	return nil
   292  }
   293  
   294  func (m mockPayloadCodec) Name() string {
   295  	return "mock"
   296  }
   297  
   298  func TestCornerCase(t *testing.T) {
   299  	ctrl := gomock.NewController(t)
   300  	defer ctrl.Finish()
   301  
   302  	sendMsg := initClientSendMsg(transport.TTHeader)
   303  	sendMsg.SetProtocolInfo(remote.NewProtocolInfo(transport.Framed, serviceinfo.Thrift))
   304  
   305  	buffer := mocksremote.NewMockByteBuffer(ctrl)
   306  	buffer.EXPECT().MallocLen().Return(1024).AnyTimes()
   307  	buffer.EXPECT().Malloc(gomock.Any()).Return(nil, errors.New("error malloc")).AnyTimes()
   308  	err := (&defaultCodec{}).EncodePayload(context.Background(), sendMsg, buffer)
   309  	test.Assert(t, err.Error() == "error malloc")
   310  
   311  	mockey.PatchConvey("", t, func() {
   312  		mockey.Mock(remote.GetPayloadCodec).Return(nil, errors.New("err get payload codec")).Build()
   313  		buffer = mocksremote.NewMockByteBuffer(ctrl)
   314  		buffer.EXPECT().MallocLen().Return(1024).AnyTimes()
   315  		buffer.EXPECT().Malloc(gomock.Any()).Return(nil, nil).AnyTimes()
   316  		err := (&defaultCodec{}).EncodePayload(context.Background(), sendMsg, buffer)
   317  		test.Assert(t, err.Error() == "err get payload codec")
   318  	})
   319  }