github.com/cloudwego/kitex@v0.9.0/pkg/generic/binarythrift_codec_test.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package generic 18 19 import ( 20 "context" 21 "testing" 22 23 "github.com/apache/thrift/lib/go/thrift" 24 25 kt "github.com/cloudwego/kitex/internal/mocks/thrift" 26 "github.com/cloudwego/kitex/internal/test" 27 "github.com/cloudwego/kitex/pkg/remote" 28 "github.com/cloudwego/kitex/pkg/rpcinfo" 29 "github.com/cloudwego/kitex/pkg/serviceinfo" 30 "github.com/cloudwego/kitex/pkg/utils" 31 ) 32 33 func TestBinaryThriftCodec(t *testing.T) { 34 req := kt.NewMockReq() 35 args := kt.NewMockTestArgs() 36 args.Req = req 37 // encode 38 rc := utils.NewThriftMessageCodec() 39 buf, err := rc.Encode("mock", thrift.CALL, 100, args) 40 test.Assert(t, err == nil, err) 41 42 btc := &binaryThriftCodec{thriftCodec} 43 cliMsg := &mockMessage{ 44 RPCInfoFunc: func() rpcinfo.RPCInfo { 45 return newMockRPCInfo() 46 }, 47 RPCRoleFunc: func() remote.RPCRole { 48 return remote.Client 49 }, 50 DataFunc: func() interface{} { 51 return &Args{ 52 Request: buf, 53 Method: "mock", 54 } 55 }, 56 } 57 seqID, err := GetSeqID(cliMsg.Data().(*Args).Request.(binaryReqType)) 58 test.Assert(t, err == nil, err) 59 test.Assert(t, seqID == 100, seqID) 60 61 rwbuf := remote.NewReaderWriterBuffer(1024) 62 // change seqID to 1 63 err = btc.Marshal(context.Background(), cliMsg, rwbuf) 64 test.Assert(t, err == nil, err) 65 seqID, err = GetSeqID(cliMsg.Data().(*Args).Request.(binaryReqType)) 66 test.Assert(t, err == nil, err) 67 test.Assert(t, seqID == 1, seqID) 68 69 // server side 70 arg := &Args{} 71 svrMsg := &mockMessage{ 72 RPCInfoFunc: func() rpcinfo.RPCInfo { 73 return newMockRPCInfo() 74 }, 75 RPCRoleFunc: func() remote.RPCRole { 76 return remote.Server 77 }, 78 DataFunc: func() interface{} { 79 return arg 80 }, 81 PayloadLenFunc: func() int { 82 return rwbuf.ReadableLen() 83 }, 84 ServiceInfoFunc: func() *serviceinfo.ServiceInfo { 85 return ServiceInfo(serviceinfo.Thrift) 86 }, 87 } 88 err = btc.Unmarshal(context.Background(), svrMsg, rwbuf) 89 test.Assert(t, err == nil, err) 90 reqBuf := svrMsg.Data().(*Args).Request.(binaryReqType) 91 seqID, err = GetSeqID(reqBuf) 92 test.Assert(t, err == nil, err) 93 test.Assert(t, seqID == 1, seqID) 94 95 var req2 kt.MockTestArgs 96 method, seqID2, err2 := rc.Decode(reqBuf, &req2) 97 test.Assert(t, err2 == nil, err) 98 test.Assert(t, seqID2 == 1, seqID) 99 test.Assert(t, method == "mock", method) 100 } 101 102 func TestBinaryThriftCodecExceptionError(t *testing.T) { 103 ctx := context.Background() 104 btc := &binaryThriftCodec{thriftCodec} 105 cliMsg := &mockMessage{ 106 RPCInfoFunc: func() rpcinfo.RPCInfo { 107 return newEmptyMethodRPCInfo() 108 }, 109 RPCRoleFunc: func() remote.RPCRole { 110 return remote.Server 111 }, 112 MessageTypeFunc: func() remote.MessageType { 113 return remote.Exception 114 }, 115 } 116 117 rwbuf := remote.NewReaderWriterBuffer(1024) 118 // test data is empty 119 err := btc.Marshal(ctx, cliMsg, rwbuf) 120 test.Assert(t, err.Error() == "invalid marshal data in rawThriftBinaryCodec: nil") 121 cliMsg.DataFunc = func() interface{} { 122 return &remote.TransError{} 123 } 124 125 // empty method 126 err = btc.Marshal(ctx, cliMsg, rwbuf) 127 test.Assert(t, err.Error() == "rawThriftBinaryCodec Marshal exception failed, err: empty methodName in thrift Marshal") 128 129 cliMsg.RPCInfoFunc = func() rpcinfo.RPCInfo { 130 return newMockRPCInfo() 131 } 132 err = btc.Marshal(ctx, cliMsg, rwbuf) 133 test.Assert(t, err == nil) 134 err = btc.Unmarshal(ctx, cliMsg, rwbuf) 135 test.Assert(t, err.Error() == "unknown application exception") 136 137 // test server role 138 cliMsg.MessageTypeFunc = func() remote.MessageType { 139 return remote.Call 140 } 141 cliMsg.DataFunc = func() interface{} { 142 return &Result{ 143 Success: binaryReqType{}, 144 } 145 } 146 err = btc.Marshal(ctx, cliMsg, rwbuf) 147 test.Assert(t, err == nil) 148 } 149 150 func newMockRPCInfo() rpcinfo.RPCInfo { 151 c := rpcinfo.NewEndpointInfo("", "", nil, nil) 152 s := rpcinfo.NewEndpointInfo("", "", nil, nil) 153 ink := rpcinfo.NewInvocation("", "mock") 154 ri := rpcinfo.NewRPCInfo(c, s, ink, nil, rpcinfo.NewRPCStats()) 155 return ri 156 } 157 158 func newEmptyMethodRPCInfo() rpcinfo.RPCInfo { 159 c := rpcinfo.NewEndpointInfo("", "", nil, nil) 160 s := rpcinfo.NewEndpointInfo("", "", nil, nil) 161 ink := rpcinfo.NewInvocation("", "") 162 ri := rpcinfo.NewRPCInfo(c, s, ink, nil, nil) 163 return ri 164 } 165 166 var _ remote.Message = &mockMessage{} 167 168 type mockMessage struct { 169 RPCInfoFunc func() rpcinfo.RPCInfo 170 ServiceInfoFunc func() *serviceinfo.ServiceInfo 171 SetServiceInfoFunc func(svcName, methodName string) (*serviceinfo.ServiceInfo, error) 172 DataFunc func() interface{} 173 NewDataFunc func(method string) (ok bool) 174 MessageTypeFunc func() remote.MessageType 175 SetMessageTypeFunc func(remote.MessageType) 176 RPCRoleFunc func() remote.RPCRole 177 PayloadLenFunc func() int 178 SetPayloadLenFunc func(size int) 179 TransInfoFunc func() remote.TransInfo 180 TagsFunc func() map[string]interface{} 181 ProtocolInfoFunc func() remote.ProtocolInfo 182 SetProtocolInfoFunc func(remote.ProtocolInfo) 183 PayloadCodecFunc func() remote.PayloadCodec 184 SetPayloadCodecFunc func(pc remote.PayloadCodec) 185 RecycleFunc func() 186 } 187 188 func (m *mockMessage) RPCInfo() rpcinfo.RPCInfo { 189 if m.RPCInfoFunc != nil { 190 return m.RPCInfoFunc() 191 } 192 return nil 193 } 194 195 func (m *mockMessage) ServiceInfo() (si *serviceinfo.ServiceInfo) { 196 if m.ServiceInfoFunc != nil { 197 return m.ServiceInfoFunc() 198 } 199 return 200 } 201 202 func (m *mockMessage) SpecifyServiceInfo(svcName, methodName string) (si *serviceinfo.ServiceInfo, err error) { 203 if m.SetServiceInfoFunc != nil { 204 return m.SetServiceInfoFunc(svcName, methodName) 205 } 206 return nil, nil 207 } 208 209 func (m *mockMessage) Data() interface{} { 210 if m.DataFunc != nil { 211 return m.DataFunc() 212 } 213 return nil 214 } 215 216 func (m *mockMessage) NewData(method string) (ok bool) { 217 if m.NewDataFunc != nil { 218 return m.NewDataFunc(method) 219 } 220 return false 221 } 222 223 func (m *mockMessage) MessageType() (mt remote.MessageType) { 224 if m.MessageTypeFunc != nil { 225 return m.MessageTypeFunc() 226 } 227 return 228 } 229 230 func (m *mockMessage) SetMessageType(mt remote.MessageType) { 231 if m.SetMessageTypeFunc != nil { 232 m.SetMessageTypeFunc(mt) 233 } 234 } 235 236 func (m *mockMessage) RPCRole() (r remote.RPCRole) { 237 if m.RPCRoleFunc != nil { 238 return m.RPCRoleFunc() 239 } 240 return 241 } 242 243 func (m *mockMessage) PayloadLen() int { 244 if m.PayloadLenFunc != nil { 245 return m.PayloadLenFunc() 246 } 247 return 0 248 } 249 250 func (m *mockMessage) SetPayloadLen(size int) { 251 if m.SetPayloadLenFunc != nil { 252 m.SetPayloadLenFunc(size) 253 } 254 } 255 256 func (m *mockMessage) TransInfo() remote.TransInfo { 257 if m.TransInfoFunc != nil { 258 return m.TransInfoFunc() 259 } 260 return nil 261 } 262 263 func (m *mockMessage) Tags() map[string]interface{} { 264 if m.TagsFunc != nil { 265 return m.TagsFunc() 266 } 267 return nil 268 } 269 270 func (m *mockMessage) ProtocolInfo() (pi remote.ProtocolInfo) { 271 if m.ProtocolInfoFunc != nil { 272 return m.ProtocolInfoFunc() 273 } 274 return 275 } 276 277 func (m *mockMessage) SetProtocolInfo(pi remote.ProtocolInfo) { 278 if m.SetProtocolInfoFunc != nil { 279 m.SetProtocolInfoFunc(pi) 280 } 281 } 282 283 func (m *mockMessage) PayloadCodec() remote.PayloadCodec { 284 if m.PayloadCodecFunc != nil { 285 return m.PayloadCodecFunc() 286 } 287 return nil 288 } 289 290 func (m *mockMessage) SetPayloadCodec(pc remote.PayloadCodec) { 291 if m.SetPayloadCodecFunc != nil { 292 m.SetPayloadCodecFunc(pc) 293 } 294 } 295 296 func (m *mockMessage) Recycle() { 297 if m.RecycleFunc != nil { 298 m.RecycleFunc() 299 } 300 }