github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/thrift/thrift_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package thrift 18 19 import ( 20 "context" 21 "errors" 22 "testing" 23 24 "github.com/apache/thrift/lib/go/thrift" 25 26 "github.com/cloudwego/kitex/internal/mocks" 27 mt "github.com/cloudwego/kitex/internal/mocks/thrift" 28 "github.com/cloudwego/kitex/internal/test" 29 "github.com/cloudwego/kitex/pkg/remote" 30 "github.com/cloudwego/kitex/pkg/rpcinfo" 31 "github.com/cloudwego/kitex/pkg/serviceinfo" 32 "github.com/cloudwego/kitex/transport" 33 ) 34 35 var ( 36 payloadCodec = &thriftCodec{FastWrite | FastRead} 37 svcInfo = mocks.ServiceInfo() 38 ) 39 40 func init() { 41 svcInfo.Methods["mock"] = serviceinfo.NewMethodInfo(nil, newMockTestArgs, nil, false) 42 } 43 44 type mockWithContext struct { 45 ReadFunc func(ctx context.Context, method string, oprot thrift.TProtocol) error 46 WriteFunc func(ctx context.Context, oprot thrift.TProtocol) error 47 } 48 49 func (m *mockWithContext) Read(ctx context.Context, method string, oprot thrift.TProtocol) error { 50 if m.ReadFunc != nil { 51 return m.ReadFunc(ctx, method, oprot) 52 } 53 return nil 54 } 55 56 func (m *mockWithContext) Write(ctx context.Context, oprot thrift.TProtocol) error { 57 if m.WriteFunc != nil { 58 return m.WriteFunc(ctx, oprot) 59 } 60 return nil 61 } 62 63 func TestWithContext(t *testing.T) { 64 ctx := context.Background() 65 66 req := &mockWithContext{WriteFunc: func(ctx context.Context, oprot thrift.TProtocol) error { 67 return nil 68 }} 69 ink := rpcinfo.NewInvocation("", "mock") 70 ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) 71 msg := remote.NewMessage(req, svcInfo, ri, remote.Call, remote.Client) 72 msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, svcInfo.PayloadCodec)) 73 out := remote.NewWriterBuffer(256) 74 err := payloadCodec.Marshal(ctx, msg, out) 75 test.Assert(t, err == nil, err) 76 77 { 78 resp := &mockWithContext{ReadFunc: func(ctx context.Context, method string, oprot thrift.TProtocol) error { return nil }} 79 ink := rpcinfo.NewInvocation("", "mock") 80 ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) 81 msg := remote.NewMessage(resp, svcInfo, ri, remote.Call, remote.Client) 82 msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, svcInfo.PayloadCodec)) 83 buf, err := out.Bytes() 84 test.Assert(t, err == nil, err) 85 msg.SetPayloadLen(len(buf)) 86 in := remote.NewReaderBuffer(buf) 87 err = payloadCodec.Unmarshal(ctx, msg, in) 88 test.Assert(t, err == nil, err) 89 } 90 } 91 92 func TestNormal(t *testing.T) { 93 ctx := context.Background() 94 95 // encode client side 96 sendMsg := initSendMsg(transport.TTHeader) 97 out := remote.NewWriterBuffer(256) 98 err := payloadCodec.Marshal(ctx, sendMsg, out) 99 test.Assert(t, err == nil, err) 100 101 // decode server side 102 recvMsg := initRecvMsg() 103 buf, err := out.Bytes() 104 recvMsg.SetPayloadLen(len(buf)) 105 test.Assert(t, err == nil, err) 106 in := remote.NewReaderBuffer(buf) 107 err = payloadCodec.Unmarshal(ctx, recvMsg, in) 108 test.Assert(t, err == nil, err) 109 110 // compare Req Arg 111 sendReq := (sendMsg.Data()).(*mt.MockTestArgs).Req 112 recvReq := (recvMsg.Data()).(*mt.MockTestArgs).Req 113 test.Assert(t, sendReq.Msg == recvReq.Msg) 114 test.Assert(t, len(sendReq.StrList) == len(recvReq.StrList)) 115 test.Assert(t, len(sendReq.StrMap) == len(recvReq.StrMap)) 116 for i, item := range sendReq.StrList { 117 test.Assert(t, item == recvReq.StrList[i]) 118 } 119 for k := range sendReq.StrMap { 120 test.Assert(t, sendReq.StrMap[k] == recvReq.StrMap[k]) 121 } 122 } 123 124 func BenchmarkNormalParallel(b *testing.B) { 125 ctx := context.Background() 126 127 b.ResetTimer() 128 b.RunParallel(func(pb *testing.PB) { 129 for pb.Next() { 130 // encode // client side 131 sendMsg := initSendMsg(transport.TTHeader) 132 out := remote.NewWriterBuffer(256) 133 err := payloadCodec.Marshal(ctx, sendMsg, out) 134 test.Assert(b, err == nil, err) 135 136 // decode server side 137 recvMsg := initRecvMsg() 138 buf, err := out.Bytes() 139 recvMsg.SetPayloadLen(len(buf)) 140 test.Assert(b, err == nil, err) 141 in := remote.NewReaderBuffer(buf) 142 err = payloadCodec.Unmarshal(ctx, recvMsg, in) 143 test.Assert(b, err == nil, err) 144 145 // compare Req Arg 146 sendReq := (sendMsg.Data()).(*mt.MockTestArgs).Req 147 recvReq := (recvMsg.Data()).(*mt.MockTestArgs).Req 148 test.Assert(b, sendReq.Msg == recvReq.Msg) 149 test.Assert(b, len(sendReq.StrList) == len(recvReq.StrList)) 150 test.Assert(b, len(sendReq.StrMap) == len(recvReq.StrMap)) 151 for i, item := range sendReq.StrList { 152 test.Assert(b, item == recvReq.StrList[i]) 153 } 154 for k := range sendReq.StrMap { 155 test.Assert(b, sendReq.StrMap[k] == recvReq.StrMap[k]) 156 } 157 } 158 }) 159 } 160 161 func TestException(t *testing.T) { 162 ctx := context.Background() 163 ink := rpcinfo.NewInvocation("", "mock") 164 ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) 165 errInfo := "mock exception" 166 transErr := remote.NewTransErrorWithMsg(remote.UnknownMethod, errInfo) 167 // encode server side 168 errMsg := initServerErrorMsg(transport.TTHeader, ri, transErr) 169 out := remote.NewWriterBuffer(256) 170 err := payloadCodec.Marshal(ctx, errMsg, out) 171 test.Assert(t, err == nil, err) 172 173 // decode client side 174 recvMsg := initClientRecvMsg(ri) 175 buf, err := out.Bytes() 176 recvMsg.SetPayloadLen(len(buf)) 177 test.Assert(t, err == nil, err) 178 in := remote.NewReaderBuffer(buf) 179 err = payloadCodec.Unmarshal(ctx, recvMsg, in) 180 test.Assert(t, err != nil) 181 transErr, ok := err.(*remote.TransError) 182 test.Assert(t, ok) 183 test.Assert(t, err.Error() == errInfo) 184 test.Assert(t, transErr.TypeID() == remote.UnknownMethod) 185 } 186 187 func TestTransErrorUnwrap(t *testing.T) { 188 errMsg := "mock err" 189 transErr := remote.NewTransError(remote.InternalError, thrift.NewTApplicationException(1000, errMsg)) 190 uwErr, ok := transErr.Unwrap().(thrift.TApplicationException) 191 test.Assert(t, ok) 192 test.Assert(t, uwErr.TypeId() == 1000) 193 test.Assert(t, transErr.Error() == errMsg) 194 195 uwErr2, ok := errors.Unwrap(transErr).(thrift.TApplicationException) 196 test.Assert(t, ok) 197 test.Assert(t, uwErr2.TypeId() == 1000) 198 test.Assert(t, uwErr2.Error() == errMsg) 199 } 200 201 func initSendMsg(tp transport.Protocol) remote.Message { 202 var _args mt.MockTestArgs 203 _args.Req = prepareReq() 204 ink := rpcinfo.NewInvocation("", "mock") 205 ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) 206 msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Client) 207 msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) 208 return msg 209 } 210 211 func initRecvMsg() remote.Message { 212 var _args mt.MockTestArgs 213 ink := rpcinfo.NewInvocation("", "mock") 214 ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) 215 msg := remote.NewMessage(&_args, svcInfo, ri, remote.Call, remote.Server) 216 return msg 217 } 218 219 func initServerErrorMsg(tp transport.Protocol, ri rpcinfo.RPCInfo, transErr *remote.TransError) remote.Message { 220 errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server) 221 errMsg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) 222 return errMsg 223 } 224 225 func initClientRecvMsg(ri rpcinfo.RPCInfo) remote.Message { 226 var resp interface{} 227 clientRecvMsg := remote.NewMessage(resp, svcInfo, ri, remote.Reply, remote.Client) 228 return clientRecvMsg 229 } 230 231 func prepareReq() *mt.MockReq { 232 strMap := make(map[string]string) 233 strMap["key1"] = "val1" 234 strMap["key2"] = "val2" 235 strList := []string{"str1", "str2"} 236 req := &mt.MockReq{ 237 Msg: "MockReq", 238 StrMap: strMap, 239 StrList: strList, 240 } 241 return req 242 } 243 244 func newMockTestArgs() interface{} { 245 return mt.NewMockTestArgs() 246 }