github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/protobuf/protobuf_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package protobuf 18 19 import ( 20 "context" 21 "errors" 22 "testing" 23 24 "google.golang.org/protobuf/proto" 25 26 "github.com/cloudwego/kitex/internal/mocks" 27 "github.com/cloudwego/kitex/internal/test" 28 "github.com/cloudwego/kitex/pkg/remote" 29 "github.com/cloudwego/kitex/pkg/rpcinfo" 30 "github.com/cloudwego/kitex/pkg/serviceinfo" 31 "github.com/cloudwego/kitex/transport" 32 ) 33 34 var ( 35 payloadCodec = &protobufCodec{} 36 svcInfo = mocks.ServiceInfo() 37 ) 38 39 func init() { 40 svcInfo.Methods["mock"] = serviceinfo.NewMethodInfo(nil, newMockReqArgs, nil, false) 41 } 42 43 func TestNormal(t *testing.T) { 44 ctx := context.Background() 45 46 // encode // client side 47 sendMsg := initSendMsg(transport.TTHeader) 48 out := remote.NewWriterBuffer(256) 49 err := payloadCodec.Marshal(ctx, sendMsg, out) 50 test.Assert(t, err == nil, err) 51 52 // decode server side 53 recvMsg := initRecvMsg() 54 buf, err := out.Bytes() 55 recvMsg.SetPayloadLen(len(buf)) 56 test.Assert(t, err == nil, err) 57 in := remote.NewReaderBuffer(buf) 58 err = payloadCodec.Unmarshal(ctx, recvMsg, in) 59 test.Assert(t, err == nil, err) 60 61 // compare Req Arg 62 sendReq := (sendMsg.Data()).(*MockReqArgs).Req 63 recvReq := (recvMsg.Data()).(*MockReqArgs).Req 64 test.Assert(t, sendReq.Msg == recvReq.Msg) 65 test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList)) 66 test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap)) 67 for i, item := range sendReq.StrList { 68 test.Assert(t, item == recvReq.StrList[i]) 69 } 70 for k := range sendReq.StrMap { 71 test.Assert(t, sendReq.StrMap[k] == recvReq.StrMap[k]) 72 } 73 } 74 75 func TestException(t *testing.T) { 76 ctx := context.Background() 77 ink := rpcinfo.NewInvocation("", "mock") 78 ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) 79 errInfo := "mock exception" 80 transErr := remote.NewTransErrorWithMsg(remote.UnknownMethod, errInfo) 81 // encode server side 82 errMsg := initServerErrorMsg(transport.TTHeader, ri, transErr) 83 out := remote.NewWriterBuffer(256) 84 err := payloadCodec.Marshal(ctx, errMsg, out) 85 test.Assert(t, err == nil, err) 86 87 // decode client side 88 recvMsg := initClientRecvMsg(ri) 89 buf, err := out.Bytes() 90 recvMsg.SetPayloadLen(len(buf)) 91 test.Assert(t, err == nil, err) 92 in := remote.NewReaderBuffer(buf) 93 err = payloadCodec.Unmarshal(ctx, recvMsg, in) 94 test.Assert(t, err != nil) 95 transErr, ok := err.(*remote.TransError) 96 test.Assert(t, ok) 97 test.Assert(t, err.Error() == errInfo) 98 test.Assert(t, transErr.Error() == errInfo) 99 test.Assert(t, transErr.TypeID() == remote.UnknownMethod) 100 } 101 102 func TestTransErrorUnwrap(t *testing.T) { 103 errMsg := "mock err" 104 transErr := remote.NewTransError(remote.InternalError, NewPbError(1000, errMsg)) 105 uwErr, ok := transErr.Unwrap().(PBError) 106 test.Assert(t, ok) 107 test.Assert(t, uwErr.TypeID() == 1000) 108 test.Assert(t, transErr.Error() == errMsg) 109 110 uwErr2, ok := errors.Unwrap(transErr).(PBError) 111 test.Assert(t, ok) 112 test.Assert(t, uwErr2.TypeID() == 1000) 113 test.Assert(t, uwErr2.Error() == errMsg) 114 } 115 116 func BenchmarkNormalParallel(b *testing.B) { 117 ctx := context.Background() 118 119 b.ResetTimer() 120 b.RunParallel(func(pb *testing.PB) { 121 for pb.Next() { 122 // encode // client side 123 sendMsg := initSendMsg(transport.TTHeader) 124 out := remote.NewWriterBuffer(256) 125 err := payloadCodec.Marshal(ctx, sendMsg, out) 126 test.Assert(b, err == nil, err) 127 128 // decode server side 129 recvMsg := initRecvMsg() 130 buf, err := out.Bytes() 131 recvMsg.SetPayloadLen(len(buf)) 132 test.Assert(b, err == nil, err) 133 in := remote.NewReaderBuffer(buf) 134 err = payloadCodec.Unmarshal(ctx, recvMsg, in) 135 test.Assert(b, err == nil, err) 136 137 // compare Req Arg 138 sendReq := (sendMsg.Data()).(*MockReqArgs).Req 139 recvReq := (recvMsg.Data()).(*MockReqArgs).Req 140 test.Assert(b, sendReq.Msg == recvReq.Msg) 141 test.Assert(b, len(sendReq.StrList) == len(recvReq.StrList)) 142 test.Assert(b, len(sendReq.StrMap) == len(recvReq.StrMap)) 143 for i, item := range sendReq.StrList { 144 test.Assert(b, item == recvReq.StrList[i]) 145 } 146 for k := range sendReq.StrMap { 147 test.Assert(b, sendReq.StrMap[k] == recvReq.StrMap[k]) 148 } 149 } 150 }) 151 } 152 153 func initSendMsg(tp transport.Protocol) remote.Message { 154 var _args MockReqArgs 155 _args.Req = prepareReq() 156 ink := rpcinfo.NewInvocation("", "mock") 157 ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) 158 msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Client) 159 msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) 160 return msg 161 } 162 163 func initRecvMsg() remote.Message { 164 var _args MockReqArgs 165 ink := rpcinfo.NewInvocation("", "mock") 166 ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) 167 msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Server) 168 return msg 169 } 170 171 func initServerErrorMsg(tp transport.Protocol, ri rpcinfo.RPCInfo, transErr *remote.TransError) remote.Message { 172 errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server) 173 errMsg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) 174 return errMsg 175 } 176 177 func initClientRecvMsg(ri rpcinfo.RPCInfo) remote.Message { 178 var resp interface{} 179 clientRecvMsg := remote.NewMessage(resp, svcInfo, ri, remote.Reply, remote.Client) 180 return clientRecvMsg 181 } 182 183 func prepareReq() *MockReq { 184 strMap := make(map[string]string) 185 strMap["key1"] = "val1" 186 strMap["key2"] = "val2" 187 strList := []string{"str1", "str2"} 188 req := &MockReq{ 189 Msg: "MockReq", 190 StrMap: strMap, 191 StrList: strList, 192 } 193 return req 194 } 195 196 func newMockReqArgs() interface{} { 197 return &MockReqArgs{} 198 } 199 200 type MockReqArgs struct { 201 Req *MockReq 202 } 203 204 func (p *MockReqArgs) Marshal(out []byte) ([]byte, error) { 205 if !p.IsSetReq() { 206 return out, nil 207 } 208 return proto.Marshal(p.Req) 209 } 210 211 func (p *MockReqArgs) Unmarshal(in []byte) error { 212 msg := new(MockReq) 213 if err := proto.Unmarshal(in, msg); err != nil { 214 return err 215 } 216 p.Req = msg 217 return nil 218 } 219 220 var STServiceTestObjReqArgsReqDEFAULT *MockReq 221 222 func (p *MockReqArgs) GetReq() *MockReq { 223 if !p.IsSetReq() { 224 return STServiceTestObjReqArgsReqDEFAULT 225 } 226 return p.Req 227 } 228 229 func (p *MockReqArgs) IsSetReq() bool { 230 return p.Req != nil 231 }