github.com/cloudwego/kitex@v0.9.0/pkg/generic/binarythrift_codec_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package generic
    18  
    19  import (
    20  	"context"
    21  	"testing"
    22  
    23  	"github.com/apache/thrift/lib/go/thrift"
    24  
    25  	kt "github.com/cloudwego/kitex/internal/mocks/thrift"
    26  	"github.com/cloudwego/kitex/internal/test"
    27  	"github.com/cloudwego/kitex/pkg/remote"
    28  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    29  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    30  	"github.com/cloudwego/kitex/pkg/utils"
    31  )
    32  
    33  func TestBinaryThriftCodec(t *testing.T) {
    34  	req := kt.NewMockReq()
    35  	args := kt.NewMockTestArgs()
    36  	args.Req = req
    37  	// encode
    38  	rc := utils.NewThriftMessageCodec()
    39  	buf, err := rc.Encode("mock", thrift.CALL, 100, args)
    40  	test.Assert(t, err == nil, err)
    41  
    42  	btc := &binaryThriftCodec{thriftCodec}
    43  	cliMsg := &mockMessage{
    44  		RPCInfoFunc: func() rpcinfo.RPCInfo {
    45  			return newMockRPCInfo()
    46  		},
    47  		RPCRoleFunc: func() remote.RPCRole {
    48  			return remote.Client
    49  		},
    50  		DataFunc: func() interface{} {
    51  			return &Args{
    52  				Request: buf,
    53  				Method:  "mock",
    54  			}
    55  		},
    56  	}
    57  	seqID, err := GetSeqID(cliMsg.Data().(*Args).Request.(binaryReqType))
    58  	test.Assert(t, err == nil, err)
    59  	test.Assert(t, seqID == 100, seqID)
    60  
    61  	rwbuf := remote.NewReaderWriterBuffer(1024)
    62  	// change seqID to 1
    63  	err = btc.Marshal(context.Background(), cliMsg, rwbuf)
    64  	test.Assert(t, err == nil, err)
    65  	seqID, err = GetSeqID(cliMsg.Data().(*Args).Request.(binaryReqType))
    66  	test.Assert(t, err == nil, err)
    67  	test.Assert(t, seqID == 1, seqID)
    68  
    69  	// server side
    70  	arg := &Args{}
    71  	svrMsg := &mockMessage{
    72  		RPCInfoFunc: func() rpcinfo.RPCInfo {
    73  			return newMockRPCInfo()
    74  		},
    75  		RPCRoleFunc: func() remote.RPCRole {
    76  			return remote.Server
    77  		},
    78  		DataFunc: func() interface{} {
    79  			return arg
    80  		},
    81  		PayloadLenFunc: func() int {
    82  			return rwbuf.ReadableLen()
    83  		},
    84  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
    85  			return ServiceInfo(serviceinfo.Thrift)
    86  		},
    87  	}
    88  	err = btc.Unmarshal(context.Background(), svrMsg, rwbuf)
    89  	test.Assert(t, err == nil, err)
    90  	reqBuf := svrMsg.Data().(*Args).Request.(binaryReqType)
    91  	seqID, err = GetSeqID(reqBuf)
    92  	test.Assert(t, err == nil, err)
    93  	test.Assert(t, seqID == 1, seqID)
    94  
    95  	var req2 kt.MockTestArgs
    96  	method, seqID2, err2 := rc.Decode(reqBuf, &req2)
    97  	test.Assert(t, err2 == nil, err)
    98  	test.Assert(t, seqID2 == 1, seqID)
    99  	test.Assert(t, method == "mock", method)
   100  }
   101  
   102  func TestBinaryThriftCodecExceptionError(t *testing.T) {
   103  	ctx := context.Background()
   104  	btc := &binaryThriftCodec{thriftCodec}
   105  	cliMsg := &mockMessage{
   106  		RPCInfoFunc: func() rpcinfo.RPCInfo {
   107  			return newEmptyMethodRPCInfo()
   108  		},
   109  		RPCRoleFunc: func() remote.RPCRole {
   110  			return remote.Server
   111  		},
   112  		MessageTypeFunc: func() remote.MessageType {
   113  			return remote.Exception
   114  		},
   115  	}
   116  
   117  	rwbuf := remote.NewReaderWriterBuffer(1024)
   118  	// test data is empty
   119  	err := btc.Marshal(ctx, cliMsg, rwbuf)
   120  	test.Assert(t, err.Error() == "invalid marshal data in rawThriftBinaryCodec: nil")
   121  	cliMsg.DataFunc = func() interface{} {
   122  		return &remote.TransError{}
   123  	}
   124  
   125  	// empty method
   126  	err = btc.Marshal(ctx, cliMsg, rwbuf)
   127  	test.Assert(t, err.Error() == "rawThriftBinaryCodec Marshal exception failed, err: empty methodName in thrift Marshal")
   128  
   129  	cliMsg.RPCInfoFunc = func() rpcinfo.RPCInfo {
   130  		return newMockRPCInfo()
   131  	}
   132  	err = btc.Marshal(ctx, cliMsg, rwbuf)
   133  	test.Assert(t, err == nil)
   134  	err = btc.Unmarshal(ctx, cliMsg, rwbuf)
   135  	test.Assert(t, err.Error() == "unknown application exception")
   136  
   137  	// test server role
   138  	cliMsg.MessageTypeFunc = func() remote.MessageType {
   139  		return remote.Call
   140  	}
   141  	cliMsg.DataFunc = func() interface{} {
   142  		return &Result{
   143  			Success: binaryReqType{},
   144  		}
   145  	}
   146  	err = btc.Marshal(ctx, cliMsg, rwbuf)
   147  	test.Assert(t, err == nil)
   148  }
   149  
   150  func newMockRPCInfo() rpcinfo.RPCInfo {
   151  	c := rpcinfo.NewEndpointInfo("", "", nil, nil)
   152  	s := rpcinfo.NewEndpointInfo("", "", nil, nil)
   153  	ink := rpcinfo.NewInvocation("", "mock")
   154  	ri := rpcinfo.NewRPCInfo(c, s, ink, nil, rpcinfo.NewRPCStats())
   155  	return ri
   156  }
   157  
   158  func newEmptyMethodRPCInfo() rpcinfo.RPCInfo {
   159  	c := rpcinfo.NewEndpointInfo("", "", nil, nil)
   160  	s := rpcinfo.NewEndpointInfo("", "", nil, nil)
   161  	ink := rpcinfo.NewInvocation("", "")
   162  	ri := rpcinfo.NewRPCInfo(c, s, ink, nil, nil)
   163  	return ri
   164  }
   165  
   166  var _ remote.Message = &mockMessage{}
   167  
   168  type mockMessage struct {
   169  	RPCInfoFunc         func() rpcinfo.RPCInfo
   170  	ServiceInfoFunc     func() *serviceinfo.ServiceInfo
   171  	SetServiceInfoFunc  func(svcName, methodName string) (*serviceinfo.ServiceInfo, error)
   172  	DataFunc            func() interface{}
   173  	NewDataFunc         func(method string) (ok bool)
   174  	MessageTypeFunc     func() remote.MessageType
   175  	SetMessageTypeFunc  func(remote.MessageType)
   176  	RPCRoleFunc         func() remote.RPCRole
   177  	PayloadLenFunc      func() int
   178  	SetPayloadLenFunc   func(size int)
   179  	TransInfoFunc       func() remote.TransInfo
   180  	TagsFunc            func() map[string]interface{}
   181  	ProtocolInfoFunc    func() remote.ProtocolInfo
   182  	SetProtocolInfoFunc func(remote.ProtocolInfo)
   183  	PayloadCodecFunc    func() remote.PayloadCodec
   184  	SetPayloadCodecFunc func(pc remote.PayloadCodec)
   185  	RecycleFunc         func()
   186  }
   187  
   188  func (m *mockMessage) RPCInfo() rpcinfo.RPCInfo {
   189  	if m.RPCInfoFunc != nil {
   190  		return m.RPCInfoFunc()
   191  	}
   192  	return nil
   193  }
   194  
   195  func (m *mockMessage) ServiceInfo() (si *serviceinfo.ServiceInfo) {
   196  	if m.ServiceInfoFunc != nil {
   197  		return m.ServiceInfoFunc()
   198  	}
   199  	return
   200  }
   201  
   202  func (m *mockMessage) SpecifyServiceInfo(svcName, methodName string) (si *serviceinfo.ServiceInfo, err error) {
   203  	if m.SetServiceInfoFunc != nil {
   204  		return m.SetServiceInfoFunc(svcName, methodName)
   205  	}
   206  	return nil, nil
   207  }
   208  
   209  func (m *mockMessage) Data() interface{} {
   210  	if m.DataFunc != nil {
   211  		return m.DataFunc()
   212  	}
   213  	return nil
   214  }
   215  
   216  func (m *mockMessage) NewData(method string) (ok bool) {
   217  	if m.NewDataFunc != nil {
   218  		return m.NewDataFunc(method)
   219  	}
   220  	return false
   221  }
   222  
   223  func (m *mockMessage) MessageType() (mt remote.MessageType) {
   224  	if m.MessageTypeFunc != nil {
   225  		return m.MessageTypeFunc()
   226  	}
   227  	return
   228  }
   229  
   230  func (m *mockMessage) SetMessageType(mt remote.MessageType) {
   231  	if m.SetMessageTypeFunc != nil {
   232  		m.SetMessageTypeFunc(mt)
   233  	}
   234  }
   235  
   236  func (m *mockMessage) RPCRole() (r remote.RPCRole) {
   237  	if m.RPCRoleFunc != nil {
   238  		return m.RPCRoleFunc()
   239  	}
   240  	return
   241  }
   242  
   243  func (m *mockMessage) PayloadLen() int {
   244  	if m.PayloadLenFunc != nil {
   245  		return m.PayloadLenFunc()
   246  	}
   247  	return 0
   248  }
   249  
   250  func (m *mockMessage) SetPayloadLen(size int) {
   251  	if m.SetPayloadLenFunc != nil {
   252  		m.SetPayloadLenFunc(size)
   253  	}
   254  }
   255  
   256  func (m *mockMessage) TransInfo() remote.TransInfo {
   257  	if m.TransInfoFunc != nil {
   258  		return m.TransInfoFunc()
   259  	}
   260  	return nil
   261  }
   262  
   263  func (m *mockMessage) Tags() map[string]interface{} {
   264  	if m.TagsFunc != nil {
   265  		return m.TagsFunc()
   266  	}
   267  	return nil
   268  }
   269  
   270  func (m *mockMessage) ProtocolInfo() (pi remote.ProtocolInfo) {
   271  	if m.ProtocolInfoFunc != nil {
   272  		return m.ProtocolInfoFunc()
   273  	}
   274  	return
   275  }
   276  
   277  func (m *mockMessage) SetProtocolInfo(pi remote.ProtocolInfo) {
   278  	if m.SetProtocolInfoFunc != nil {
   279  		m.SetProtocolInfoFunc(pi)
   280  	}
   281  }
   282  
   283  func (m *mockMessage) PayloadCodec() remote.PayloadCodec {
   284  	if m.PayloadCodecFunc != nil {
   285  		return m.PayloadCodecFunc()
   286  	}
   287  	return nil
   288  }
   289  
   290  func (m *mockMessage) SetPayloadCodec(pc remote.PayloadCodec) {
   291  	if m.SetPayloadCodecFunc != nil {
   292  		m.SetPayloadCodecFunc(pc)
   293  	}
   294  }
   295  
   296  func (m *mockMessage) Recycle() {
   297  	if m.RecycleFunc != nil {
   298  		m.RecycleFunc()
   299  	}
   300  }