github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/protobuf/protobuf_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package protobuf
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"testing"
    23  
    24  	"google.golang.org/protobuf/proto"
    25  
    26  	"github.com/cloudwego/kitex/internal/mocks"
    27  	"github.com/cloudwego/kitex/internal/test"
    28  	"github.com/cloudwego/kitex/pkg/remote"
    29  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    30  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    31  	"github.com/cloudwego/kitex/transport"
    32  )
    33  
    34  var (
    35  	payloadCodec = &protobufCodec{}
    36  	svcInfo      = mocks.ServiceInfo()
    37  )
    38  
    39  func init() {
    40  	svcInfo.Methods["mock"] = serviceinfo.NewMethodInfo(nil, newMockReqArgs, nil, false)
    41  }
    42  
    43  func TestNormal(t *testing.T) {
    44  	ctx := context.Background()
    45  
    46  	// encode // client side
    47  	sendMsg := initSendMsg(transport.TTHeader)
    48  	out := remote.NewWriterBuffer(256)
    49  	err := payloadCodec.Marshal(ctx, sendMsg, out)
    50  	test.Assert(t, err == nil, err)
    51  
    52  	// decode server side
    53  	recvMsg := initRecvMsg()
    54  	buf, err := out.Bytes()
    55  	recvMsg.SetPayloadLen(len(buf))
    56  	test.Assert(t, err == nil, err)
    57  	in := remote.NewReaderBuffer(buf)
    58  	err = payloadCodec.Unmarshal(ctx, recvMsg, in)
    59  	test.Assert(t, err == nil, err)
    60  
    61  	// compare Req Arg
    62  	sendReq := (sendMsg.Data()).(*MockReqArgs).Req
    63  	recvReq := (recvMsg.Data()).(*MockReqArgs).Req
    64  	test.Assert(t, sendReq.Msg == recvReq.Msg)
    65  	test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList))
    66  	test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap))
    67  	for i, item := range sendReq.StrList {
    68  		test.Assert(t, item == recvReq.StrList[i])
    69  	}
    70  	for k := range sendReq.StrMap {
    71  		test.Assert(t, sendReq.StrMap[k] == recvReq.StrMap[k])
    72  	}
    73  }
    74  
    75  func TestException(t *testing.T) {
    76  	ctx := context.Background()
    77  	ink := rpcinfo.NewInvocation("", "mock")
    78  	ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
    79  	errInfo := "mock exception"
    80  	transErr := remote.NewTransErrorWithMsg(remote.UnknownMethod, errInfo)
    81  	// encode server side
    82  	errMsg := initServerErrorMsg(transport.TTHeader, ri, transErr)
    83  	out := remote.NewWriterBuffer(256)
    84  	err := payloadCodec.Marshal(ctx, errMsg, out)
    85  	test.Assert(t, err == nil, err)
    86  
    87  	// decode client side
    88  	recvMsg := initClientRecvMsg(ri)
    89  	buf, err := out.Bytes()
    90  	recvMsg.SetPayloadLen(len(buf))
    91  	test.Assert(t, err == nil, err)
    92  	in := remote.NewReaderBuffer(buf)
    93  	err = payloadCodec.Unmarshal(ctx, recvMsg, in)
    94  	test.Assert(t, err != nil)
    95  	transErr, ok := err.(*remote.TransError)
    96  	test.Assert(t, ok)
    97  	test.Assert(t, err.Error() == errInfo)
    98  	test.Assert(t, transErr.Error() == errInfo)
    99  	test.Assert(t, transErr.TypeID() == remote.UnknownMethod)
   100  }
   101  
   102  func TestTransErrorUnwrap(t *testing.T) {
   103  	errMsg := "mock err"
   104  	transErr := remote.NewTransError(remote.InternalError, NewPbError(1000, errMsg))
   105  	uwErr, ok := transErr.Unwrap().(PBError)
   106  	test.Assert(t, ok)
   107  	test.Assert(t, uwErr.TypeID() == 1000)
   108  	test.Assert(t, transErr.Error() == errMsg)
   109  
   110  	uwErr2, ok := errors.Unwrap(transErr).(PBError)
   111  	test.Assert(t, ok)
   112  	test.Assert(t, uwErr2.TypeID() == 1000)
   113  	test.Assert(t, uwErr2.Error() == errMsg)
   114  }
   115  
   116  func BenchmarkNormalParallel(b *testing.B) {
   117  	ctx := context.Background()
   118  
   119  	b.ResetTimer()
   120  	b.RunParallel(func(pb *testing.PB) {
   121  		for pb.Next() {
   122  			// encode // client side
   123  			sendMsg := initSendMsg(transport.TTHeader)
   124  			out := remote.NewWriterBuffer(256)
   125  			err := payloadCodec.Marshal(ctx, sendMsg, out)
   126  			test.Assert(b, err == nil, err)
   127  
   128  			// decode server side
   129  			recvMsg := initRecvMsg()
   130  			buf, err := out.Bytes()
   131  			recvMsg.SetPayloadLen(len(buf))
   132  			test.Assert(b, err == nil, err)
   133  			in := remote.NewReaderBuffer(buf)
   134  			err = payloadCodec.Unmarshal(ctx, recvMsg, in)
   135  			test.Assert(b, err == nil, err)
   136  
   137  			// compare Req Arg
   138  			sendReq := (sendMsg.Data()).(*MockReqArgs).Req
   139  			recvReq := (recvMsg.Data()).(*MockReqArgs).Req
   140  			test.Assert(b, sendReq.Msg == recvReq.Msg)
   141  			test.Assert(b, len(sendReq.StrList) == len(recvReq.StrList))
   142  			test.Assert(b, len(sendReq.StrMap) == len(recvReq.StrMap))
   143  			for i, item := range sendReq.StrList {
   144  				test.Assert(b, item == recvReq.StrList[i])
   145  			}
   146  			for k := range sendReq.StrMap {
   147  				test.Assert(b, sendReq.StrMap[k] == recvReq.StrMap[k])
   148  			}
   149  		}
   150  	})
   151  }
   152  
   153  func initSendMsg(tp transport.Protocol) remote.Message {
   154  	var _args MockReqArgs
   155  	_args.Req = prepareReq()
   156  	ink := rpcinfo.NewInvocation("", "mock")
   157  	ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
   158  	msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Client)
   159  	msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec))
   160  	return msg
   161  }
   162  
   163  func initRecvMsg() remote.Message {
   164  	var _args MockReqArgs
   165  	ink := rpcinfo.NewInvocation("", "mock")
   166  	ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil)
   167  	msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Server)
   168  	return msg
   169  }
   170  
   171  func initServerErrorMsg(tp transport.Protocol, ri rpcinfo.RPCInfo, transErr *remote.TransError) remote.Message {
   172  	errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server)
   173  	errMsg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec))
   174  	return errMsg
   175  }
   176  
   177  func initClientRecvMsg(ri rpcinfo.RPCInfo) remote.Message {
   178  	var resp interface{}
   179  	clientRecvMsg := remote.NewMessage(resp, svcInfo, ri, remote.Reply, remote.Client)
   180  	return clientRecvMsg
   181  }
   182  
   183  func prepareReq() *MockReq {
   184  	strMap := make(map[string]string)
   185  	strMap["key1"] = "val1"
   186  	strMap["key2"] = "val2"
   187  	strList := []string{"str1", "str2"}
   188  	req := &MockReq{
   189  		Msg:     "MockReq",
   190  		StrMap:  strMap,
   191  		StrList: strList,
   192  	}
   193  	return req
   194  }
   195  
   196  func newMockReqArgs() interface{} {
   197  	return &MockReqArgs{}
   198  }
   199  
   200  type MockReqArgs struct {
   201  	Req *MockReq
   202  }
   203  
   204  func (p *MockReqArgs) Marshal(out []byte) ([]byte, error) {
   205  	if !p.IsSetReq() {
   206  		return out, nil
   207  	}
   208  	return proto.Marshal(p.Req)
   209  }
   210  
   211  func (p *MockReqArgs) Unmarshal(in []byte) error {
   212  	msg := new(MockReq)
   213  	if err := proto.Unmarshal(in, msg); err != nil {
   214  		return err
   215  	}
   216  	p.Req = msg
   217  	return nil
   218  }
   219  
   220  var STServiceTestObjReqArgsReqDEFAULT *MockReq
   221  
   222  func (p *MockReqArgs) GetReq() *MockReq {
   223  	if !p.IsSetReq() {
   224  		return STServiceTestObjReqArgsReqDEFAULT
   225  	}
   226  	return p.Req
   227  }
   228  
   229  func (p *MockReqArgs) IsSetReq() bool {
   230  	return p.Req != nil
   231  }