github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/thrift/thrift_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"testing"
    23  
    24  	"github.com/apache/thrift/lib/go/thrift"
    25  
    26  	"github.com/cloudwego/kitex/internal/mocks"
    27  	mt "github.com/cloudwego/kitex/internal/mocks/thrift"
    28  	"github.com/cloudwego/kitex/internal/test"
    29  	"github.com/cloudwego/kitex/pkg/remote"
    30  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    31  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    32  	"github.com/cloudwego/kitex/transport"
    33  )
    34  
    35  var (
    36  	payloadCodec = &thriftCodec{FastWrite | FastRead}
    37  	svcInfo      = mocks.ServiceInfo()
    38  )
    39  
    40  func init() {
    41  	svcInfo.Methods["mock"] = serviceinfo.NewMethodInfo(nil, newMockTestArgs, nil, false)
    42  }
    43  
    44  type mockWithContext struct {
    45  	ReadFunc  func(ctx context.Context, method string, oprot thrift.TProtocol) error
    46  	WriteFunc func(ctx context.Context, oprot thrift.TProtocol) error
    47  }
    48  
    49  func (m *mockWithContext) Read(ctx context.Context, method string, oprot thrift.TProtocol) error {
    50  	if m.ReadFunc != nil {
    51  		return m.ReadFunc(ctx, method, oprot)
    52  	}
    53  	return nil
    54  }
    55  
    56  func (m *mockWithContext) Write(ctx context.Context, oprot thrift.TProtocol) error {
    57  	if m.WriteFunc != nil {
    58  		return m.WriteFunc(ctx, oprot)
    59  	}
    60  	return nil
    61  }
    62  
    63  func TestWithContext(t *testing.T) {
    64  	ctx := context.Background()
    65  
    66  	req := &mockWithContext{WriteFunc: func(ctx context.Context, oprot thrift.TProtocol) error {
    67  		return nil
    68  	}}
    69  	ink := rpcinfo.NewInvocation("", "mock")
    70  	ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
    71  	msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client)
    72  	msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, svcInfo.PayloadCodec))
    73  	out := remote.NewWriterBuffer(256)
    74  	err := payloadCodec.Marshal(ctx, msg, out)
    75  	test.Assert(t, err == nil, err)
    76  
    77  	{
    78  		resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, oprot thrift.TProtocol) error { return nil }}
    79  		ink := rpcinfo.NewInvocation("", "mock")
    80  		ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
    81  		msg := remote.NewMessage(resp, svcInfo, ri, remote.Call, remote.Client)
    82  		msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, svcInfo.PayloadCodec))
    83  		buf, err := out.Bytes()
    84  		test.Assert(t, err == nil, err)
    85  		msg.SetPayloadLen(len(buf))
    86  		in := remote.NewReaderBuffer(buf)
    87  		err = payloadCodec.Unmarshal(ctx, msg, in)
    88  		test.Assert(t, err == nil, err)
    89  	}
    90  }
    91  
    92  func TestNormal(t *testing.T) {
    93  	ctx := context.Background()
    94  
    95  	// encode client side
    96  	sendMsg := initSendMsg(transport.TTHeader)
    97  	out := remote.NewWriterBuffer(256)
    98  	err := payloadCodec.Marshal(ctx, sendMsg, out)
    99  	test.Assert(t, err == nil, err)
   100  
   101  	// decode server side
   102  	recvMsg := initRecvMsg()
   103  	buf, err := out.Bytes()
   104  	recvMsg.SetPayloadLen(len(buf))
   105  	test.Assert(t, err == nil, err)
   106  	in := remote.NewReaderBuffer(buf)
   107  	err = payloadCodec.Unmarshal(ctx, recvMsg, in)
   108  	test.Assert(t, err == nil, err)
   109  
   110  	// compare Req Arg
   111  	sendReq := (sendMsg.Data()).(*mt.MockTestArgs).Req
   112  	recvReq := (recvMsg.Data()).(*mt.MockTestArgs).Req
   113  	test.Assert(t, sendReq.Msg == recvReq.Msg)
   114  	test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList))
   115  	test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap))
   116  	for i, item := range sendReq.StrList {
   117  		test.Assert(t, item == recvReq.StrList[i])
   118  	}
   119  	for k := range sendReq.StrMap {
   120  		test.Assert(t, sendReq.StrMap[k] == recvReq.StrMap[k])
   121  	}
   122  }
   123  
   124  func BenchmarkNormalParallel(b *testing.B) {
   125  	ctx := context.Background()
   126  
   127  	b.ResetTimer()
   128  	b.RunParallel(func(pb *testing.PB) {
   129  		for pb.Next() {
   130  			// encode // client side
   131  			sendMsg := initSendMsg(transport.TTHeader)
   132  			out := remote.NewWriterBuffer(256)
   133  			err := payloadCodec.Marshal(ctx, sendMsg, out)
   134  			test.Assert(b, err == nil, err)
   135  
   136  			// decode server side
   137  			recvMsg := initRecvMsg()
   138  			buf, err := out.Bytes()
   139  			recvMsg.SetPayloadLen(len(buf))
   140  			test.Assert(b, err == nil, err)
   141  			in := remote.NewReaderBuffer(buf)
   142  			err = payloadCodec.Unmarshal(ctx, recvMsg, in)
   143  			test.Assert(b, err == nil, err)
   144  
   145  			// compare Req Arg
   146  			sendReq := (sendMsg.Data()).(*mt.MockTestArgs).Req
   147  			recvReq := (recvMsg.Data()).(*mt.MockTestArgs).Req
   148  			test.Assert(b, sendReq.Msg == recvReq.Msg)
   149  			test.Assert(b, len(sendReq.StrList) == len(recvReq.StrList))
   150  			test.Assert(b, len(sendReq.StrMap) == len(recvReq.StrMap))
   151  			for i, item := range sendReq.StrList {
   152  				test.Assert(b, item == recvReq.StrList[i])
   153  			}
   154  			for k := range sendReq.StrMap {
   155  				test.Assert(b, sendReq.StrMap[k] == recvReq.StrMap[k])
   156  			}
   157  		}
   158  	})
   159  }
   160  
   161  func TestException(t *testing.T) {
   162  	ctx := context.Background()
   163  	ink := rpcinfo.NewInvocation("", "mock")
   164  	ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
   165  	errInfo := "mock exception"
   166  	transErr := remote.NewTransErrorWithMsg(remote.UnknownMethod, errInfo)
   167  	// encode server side
   168  	errMsg := initServerErrorMsg(transport.TTHeader, ri, transErr)
   169  	out := remote.NewWriterBuffer(256)
   170  	err := payloadCodec.Marshal(ctx, errMsg, out)
   171  	test.Assert(t, err == nil, err)
   172  
   173  	// decode client side
   174  	recvMsg := initClientRecvMsg(ri)
   175  	buf, err := out.Bytes()
   176  	recvMsg.SetPayloadLen(len(buf))
   177  	test.Assert(t, err == nil, err)
   178  	in := remote.NewReaderBuffer(buf)
   179  	err = payloadCodec.Unmarshal(ctx, recvMsg, in)
   180  	test.Assert(t, err != nil)
   181  	transErr, ok := err.(*remote.TransError)
   182  	test.Assert(t, ok)
   183  	test.Assert(t, err.Error() == errInfo)
   184  	test.Assert(t, transErr.TypeID() == remote.UnknownMethod)
   185  }
   186  
   187  func TestTransErrorUnwrap(t *testing.T) {
   188  	errMsg := "mock err"
   189  	transErr := remote.NewTransError(remote.InternalError, thrift.NewTApplicationException(1000, errMsg))
   190  	uwErr, ok := transErr.Unwrap().(thrift.TApplicationException)
   191  	test.Assert(t, ok)
   192  	test.Assert(t, uwErr.TypeId() == 1000)
   193  	test.Assert(t, transErr.Error() == errMsg)
   194  
   195  	uwErr2, ok := errors.Unwrap(transErr).(thrift.TApplicationException)
   196  	test.Assert(t, ok)
   197  	test.Assert(t, uwErr2.TypeId() == 1000)
   198  	test.Assert(t, uwErr2.Error() == errMsg)
   199  }
   200  
   201  func initSendMsg(tp transport.Protocol) remote.Message {
   202  	var _args mt.MockTestArgs
   203  	_args.Req = prepareReq()
   204  	ink := rpcinfo.NewInvocation("", "mock")
   205  	ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
   206  	msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Client)
   207  	msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec))
   208  	return msg
   209  }
   210  
   211  func initRecvMsg() remote.Message {
   212  	var _args mt.MockTestArgs
   213  	ink := rpcinfo.NewInvocation("", "mock")
   214  	ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
   215  	msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Server)
   216  	return msg
   217  }
   218  
   219  func initServerErrorMsg(tp transport.Protocol, ri rpcinfo.RPCInfo, transErr *remote.TransError) remote.Message {
   220  	errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server)
   221  	errMsg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec))
   222  	return errMsg
   223  }
   224  
   225  func initClientRecvMsg(ri rpcinfo.RPCInfo) remote.Message {
   226  	var resp interface{}
   227  	clientRecvMsg := remote.NewMessage(resp, svcInfo, ri, remote.Reply, remote.Client)
   228  	return clientRecvMsg
   229  }
   230  
   231  func prepareReq() *mt.MockReq {
   232  	strMap := make(map[string]string)
   233  	strMap["key1"] = "val1"
   234  	strMap["key2"] = "val2"
   235  	strList := []string{"str1", "str2"}
   236  	req := &mt.MockReq{
   237  		Msg:     "MockReq",
   238  		StrMap:  strMap,
   239  		StrList: strList,
   240  	}
   241  	return req
   242  }
   243  
   244  func newMockTestArgs() interface{} {
   245  	return mt.NewMockTestArgs()
   246  }