github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/default_server_handler_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package trans
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"net"
    23  	"testing"
    24  
    25  	"github.com/golang/mock/gomock"
    26  
    27  	"github.com/cloudwego/kitex/internal/mocks"
    28  	"github.com/cloudwego/kitex/internal/mocks/stats"
    29  	"github.com/cloudwego/kitex/internal/test"
    30  	"github.com/cloudwego/kitex/pkg/kerrors"
    31  	"github.com/cloudwego/kitex/pkg/remote"
    32  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    33  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    34  )
    35  
    36  var (
    37  	svcInfo      = mocks.ServiceInfo()
    38  	svcSearchMap = map[string]*serviceinfo.ServiceInfo{
    39  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod):          svcInfo,
    40  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo,
    41  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod):     svcInfo,
    42  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod):    svcInfo,
    43  		mocks.MockMethod:          svcInfo,
    44  		mocks.MockExceptionMethod: svcInfo,
    45  		mocks.MockErrorMethod:     svcInfo,
    46  		mocks.MockOnewayMethod:    svcInfo,
    47  	}
    48  )
    49  
    50  func TestDefaultSvrTransHandler(t *testing.T) {
    51  	buf := remote.NewReaderWriterBuffer(1024)
    52  	ext := &MockExtension{
    53  		NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
    54  			return buf
    55  		},
    56  		NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
    57  			return buf
    58  		},
    59  	}
    60  
    61  	tagEncode, tagDecode := 0, 0
    62  	opt := &remote.ServerOption{
    63  		Codec: &MockCodec{
    64  			EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
    65  				tagEncode++
    66  				test.Assert(t, out == buf)
    67  				return nil
    68  			},
    69  			DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
    70  				tagDecode++
    71  				test.Assert(t, in == buf)
    72  				return nil
    73  			},
    74  		},
    75  		SvcSearchMap:  svcSearchMap,
    76  		TargetSvcInfo: svcInfo,
    77  	}
    78  
    79  	handler, err := NewDefaultSvrTransHandler(opt, ext)
    80  	test.Assert(t, err == nil)
    81  
    82  	ctx := context.Background()
    83  	conn := &mocks.Conn{}
    84  	msg := &MockMessage{
    85  		RPCInfoFunc: func() rpcinfo.RPCInfo {
    86  			return newMockRPCInfo()
    87  		},
    88  		ServiceInfoFunc: func() *serviceinfo.ServiceInfo {
    89  			return &serviceinfo.ServiceInfo{
    90  				Methods: map[string]serviceinfo.MethodInfo{
    91  					"method": serviceinfo.NewMethodInfo(nil, nil, nil, false),
    92  				},
    93  			}
    94  		},
    95  	}
    96  	ctx, err = handler.Write(ctx, conn, msg)
    97  	test.Assert(t, ctx != nil, ctx)
    98  	test.Assert(t, err == nil, err)
    99  	test.Assert(t, tagEncode == 1, tagEncode)
   100  	test.Assert(t, tagDecode == 0, tagDecode)
   101  
   102  	ctx, err = handler.Read(ctx, conn, msg)
   103  	test.Assert(t, ctx != nil, ctx)
   104  	test.Assert(t, err == nil, err)
   105  	test.Assert(t, tagEncode == 1, tagEncode)
   106  	test.Assert(t, tagDecode == 1, tagDecode)
   107  }
   108  
   109  func TestSvrTransHandlerBizError(t *testing.T) {
   110  	ctrl := gomock.NewController(t)
   111  	defer ctrl.Finish()
   112  
   113  	mockTracer := stats.NewMockTracer(ctrl)
   114  	mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes()
   115  	mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) {
   116  		err := rpcinfo.GetRPCInfo(ctx).Stats().Error()
   117  		test.Assert(t, err != nil)
   118  	}).AnyTimes()
   119  
   120  	buf := remote.NewReaderWriterBuffer(1024)
   121  	ext := &MockExtension{
   122  		NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
   123  			return buf
   124  		},
   125  		NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
   126  			return buf
   127  		},
   128  	}
   129  
   130  	tracerCtl := &rpcinfo.TraceController{}
   131  	tracerCtl.Append(mockTracer)
   132  	opt := &remote.ServerOption{
   133  		Codec: &MockCodec{
   134  			EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
   135  				return nil
   136  			},
   137  			DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
   138  				msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod)
   139  				return nil
   140  			},
   141  		},
   142  		SvcSearchMap:  svcSearchMap,
   143  		TargetSvcInfo: svcInfo,
   144  		TracerCtl:     tracerCtl,
   145  		InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
   146  			rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr)
   147  			return ri
   148  		},
   149  	}
   150  	ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}),
   151  		rpcinfo.NewInvocation("", mocks.MockMethod), nil, rpcinfo.NewRPCStats())
   152  	ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   153  
   154  	svrHandler, err := NewDefaultSvrTransHandler(opt, ext)
   155  	pl := remote.NewTransPipeline(svrHandler)
   156  	svrHandler.SetPipeline(pl)
   157  	if setter, ok := svrHandler.(remote.InvokeHandleFuncSetter); ok {
   158  		setter.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
   159  			return kerrors.ErrBiz.WithCause(errors.New("mock"))
   160  		})
   161  	}
   162  	test.Assert(t, err == nil)
   163  	err = svrHandler.OnRead(ctx, &mocks.Conn{})
   164  	test.Assert(t, err == nil)
   165  }
   166  
   167  func TestSvrTransHandlerReadErr(t *testing.T) {
   168  	ctrl := gomock.NewController(t)
   169  	defer ctrl.Finish()
   170  
   171  	mockTracer := stats.NewMockTracer(ctrl)
   172  	mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes()
   173  	mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) {
   174  		err := rpcinfo.GetRPCInfo(ctx).Stats().Error()
   175  		test.Assert(t, err != nil)
   176  	}).AnyTimes()
   177  
   178  	buf := remote.NewReaderWriterBuffer(1024)
   179  	ext := &MockExtension{
   180  		NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
   181  			return buf
   182  		},
   183  		NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
   184  			return buf
   185  		},
   186  	}
   187  
   188  	mockErr := errors.New("mock")
   189  	tracerCtl := &rpcinfo.TraceController{}
   190  	tracerCtl.Append(mockTracer)
   191  	opt := &remote.ServerOption{
   192  		Codec: &MockCodec{
   193  			EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
   194  				return nil
   195  			},
   196  			DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
   197  				msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod)
   198  				return mockErr
   199  			},
   200  		},
   201  		SvcSearchMap:  svcSearchMap,
   202  		TargetSvcInfo: svcInfo,
   203  		TracerCtl:     tracerCtl,
   204  		InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
   205  			rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr)
   206  			return ri
   207  		},
   208  	}
   209  	ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}),
   210  		rpcinfo.NewInvocation("", mocks.MockMethod), nil, rpcinfo.NewRPCStats())
   211  	ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   212  
   213  	svrHandler, err := NewDefaultSvrTransHandler(opt, ext)
   214  	test.Assert(t, err == nil)
   215  	pl := remote.NewTransPipeline(svrHandler)
   216  	svrHandler.SetPipeline(pl)
   217  	err = svrHandler.OnRead(ctx, &mocks.Conn{})
   218  	test.Assert(t, err != nil)
   219  	test.Assert(t, errors.Is(err, mockErr))
   220  }
   221  
   222  func TestSvrTransHandlerOnReadHeartbeat(t *testing.T) {
   223  	ctrl := gomock.NewController(t)
   224  	defer ctrl.Finish()
   225  
   226  	mockTracer := stats.NewMockTracer(ctrl)
   227  	mockTracer.EXPECT().Start(gomock.Any()).DoAndReturn(func(ctx context.Context) context.Context { return ctx }).AnyTimes()
   228  	mockTracer.EXPECT().Finish(gomock.Any()).DoAndReturn(func(ctx context.Context) {
   229  		err := rpcinfo.GetRPCInfo(ctx).Stats().Error()
   230  		test.Assert(t, err == nil)
   231  	}).AnyTimes()
   232  
   233  	buf := remote.NewReaderWriterBuffer(1024)
   234  	ext := &MockExtension{
   235  		NewWriteByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
   236  			return buf
   237  		},
   238  		NewReadByteBufferFunc: func(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
   239  			return buf
   240  		},
   241  	}
   242  
   243  	tracerCtl := &rpcinfo.TraceController{}
   244  	tracerCtl.Append(mockTracer)
   245  	opt := &remote.ServerOption{
   246  		Codec: &MockCodec{
   247  			EncodeFunc: func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error {
   248  				if msg.MessageType() != remote.Heartbeat {
   249  					return errors.New("response is not of MessageType Heartbeat")
   250  				}
   251  				return nil
   252  			},
   253  			DecodeFunc: func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error {
   254  				msg.SetMessageType(remote.Heartbeat)
   255  				msg.SpecifyServiceInfo(mocks.MockServiceName, mocks.MockMethod)
   256  				return nil
   257  			},
   258  		},
   259  		SvcSearchMap:  svcSearchMap,
   260  		TargetSvcInfo: svcInfo,
   261  		TracerCtl:     tracerCtl,
   262  		InitOrResetRPCInfoFunc: func(ri rpcinfo.RPCInfo, addr net.Addr) rpcinfo.RPCInfo {
   263  			rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(addr)
   264  			return ri
   265  		},
   266  	}
   267  	ri := rpcinfo.NewRPCInfo(rpcinfo.EmptyEndpointInfo(), rpcinfo.FromBasicInfo(&rpcinfo.EndpointBasicInfo{}),
   268  		rpcinfo.NewInvocation("", mocks.MockMethod), nil, rpcinfo.NewRPCStats())
   269  	ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   270  
   271  	svrHandler, err := NewDefaultSvrTransHandler(opt, ext)
   272  	test.Assert(t, err == nil)
   273  	pl := remote.NewTransPipeline(svrHandler)
   274  	svrHandler.SetPipeline(pl)
   275  	err = svrHandler.OnRead(ctx, &mocks.Conn{})
   276  	test.Assert(t, err == nil)
   277  }