github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/detection/server_handler_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package detection
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"net"
    24  	"testing"
    25  
    26  	"github.com/golang/mock/gomock"
    27  
    28  	"github.com/cloudwego/kitex/internal/mocks"
    29  	mocksklog "github.com/cloudwego/kitex/internal/mocks/klog"
    30  	npmocks "github.com/cloudwego/kitex/internal/mocks/netpoll"
    31  	remote_mocks "github.com/cloudwego/kitex/internal/mocks/remote"
    32  	"github.com/cloudwego/kitex/internal/test"
    33  	"github.com/cloudwego/kitex/pkg/klog"
    34  	"github.com/cloudwego/kitex/pkg/remote"
    35  	"github.com/cloudwego/kitex/pkg/remote/codec"
    36  	"github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
    37  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
    38  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
    39  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    40  	"github.com/cloudwego/kitex/pkg/utils"
    41  )
    42  
    43  var (
    44  	prefaceReadAtMost = func() int {
    45  		// min(len(ClientPreface), len(flagBuf))
    46  		// len(flagBuf) = 2 * codec.Size32
    47  		if 2*codec.Size32 < grpc.ClientPrefaceLen {
    48  			return 2 * codec.Size32
    49  		}
    50  		return grpc.ClientPrefaceLen
    51  	}()
    52  	svcInfo      = mocks.ServiceInfo()
    53  	svcSearchMap = map[string]*serviceinfo.ServiceInfo{
    54  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod):          svcInfo,
    55  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo,
    56  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod):     svcInfo,
    57  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod):    svcInfo,
    58  		mocks.MockMethod:          svcInfo,
    59  		mocks.MockExceptionMethod: svcInfo,
    60  		mocks.MockErrorMethod:     svcInfo,
    61  		mocks.MockOnewayMethod:    svcInfo,
    62  	}
    63  )
    64  
    65  func TestServerHandlerCall(t *testing.T) {
    66  	transHdler, _ := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{
    67  		SvcSearchMap:  svcSearchMap,
    68  		TargetSvcInfo: svcInfo,
    69  	})
    70  
    71  	ctrl := gomock.NewController(t)
    72  	defer ctrl.Finish()
    73  
    74  	npConn := npmocks.NewMockConnection(ctrl)
    75  	npReader := npmocks.NewMockReader(ctrl)
    76  	hdl := remote_mocks.NewMockServerTransHandler(ctrl)
    77  
    78  	errOnActive := errors.New("mock on active error")
    79  	errOnRead := errors.New("mock on read error")
    80  
    81  	triggerReadErr := false
    82  	triggerActiveErr := false
    83  
    84  	hdl.EXPECT().OnActive(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, conn net.Conn) (context.Context, error) {
    85  		if triggerActiveErr {
    86  			return ctx, errOnActive
    87  		}
    88  		return ctx, nil
    89  	}).AnyTimes()
    90  	hdl.EXPECT().OnRead(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, conn net.Conn) error {
    91  		if triggerReadErr {
    92  			return errOnRead
    93  		}
    94  		return nil
    95  	}).AnyTimes()
    96  	hdl.EXPECT().OnInactive(gomock.Any(), gomock.Any()).AnyTimes()
    97  	hdl.EXPECT().OnError(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, err error, conn net.Conn) {
    98  		if conn != nil {
    99  			klog.CtxErrorf(ctx, "KITEX: processing error, remoteAddr=%v, error=%s", conn.RemoteAddr(), err.Error())
   100  		} else {
   101  			klog.CtxErrorf(ctx, "KITEX: processing error, error=%s", err.Error())
   102  		}
   103  	}).AnyTimes()
   104  
   105  	npReader.EXPECT().Peek(prefaceReadAtMost).Return([]byte("connection prefix bytes"), nil).AnyTimes()
   106  	npConn.EXPECT().Reader().Return(npReader).AnyTimes()
   107  	npConn.EXPECT().RemoteAddr().Return(nil).AnyTimes()
   108  
   109  	transHdler.(*svrTransHandler).defaultHandler = hdl
   110  
   111  	// case1 successful call: onActive() and onRead() all success
   112  	triggerActiveErr = false
   113  	triggerReadErr = false
   114  	err := mockCall(transHdler, npConn)
   115  	test.Assert(t, err == nil, err)
   116  
   117  	// case2 onActive failed: onActive() err and close conn
   118  	triggerActiveErr = true
   119  	triggerReadErr = false
   120  	err = mockCall(transHdler, npConn)
   121  	test.Assert(t, err == errOnActive, err)
   122  
   123  	// case3 onRead failed: onRead() err and close conn
   124  	triggerActiveErr = false
   125  	triggerReadErr = true
   126  	err = mockCall(transHdler, npConn)
   127  	test.Assert(t, err == errOnRead, err)
   128  }
   129  
   130  func TestOnError(t *testing.T) {
   131  	ctrl := gomock.NewController(t)
   132  	defer func() {
   133  		klog.SetLogger(klog.DefaultLogger())
   134  		ctrl.Finish()
   135  	}()
   136  	transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{
   137  		SvcSearchMap:  svcSearchMap,
   138  		TargetSvcInfo: svcInfo,
   139  	})
   140  	test.Assert(t, err == nil)
   141  
   142  	mocklogger := mocksklog.NewMockFullLogger(ctrl)
   143  	klog.SetLogger(mocklogger)
   144  
   145  	var errMsg string
   146  	mocklogger.EXPECT().CtxErrorf(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(ctx context.Context, format string, v ...interface{}) {
   147  		errMsg = fmt.Sprintf(format, v...)
   148  	})
   149  	transHdler.OnError(context.Background(), errors.New("mock error"), nil)
   150  	test.Assert(t, errMsg == "KITEX: processing error, error=mock error", errMsg)
   151  
   152  	conn := &mocks.Conn{
   153  		RemoteAddrFunc: func() (r net.Addr) {
   154  			return utils.NewNetAddr("mock", "mock")
   155  		},
   156  	}
   157  	mocklogger.EXPECT().CtxErrorf(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(ctx context.Context, format string, v ...interface{}) {
   158  		errMsg = fmt.Sprintf(format, v...)
   159  	})
   160  	transHdler.OnError(context.Background(), errors.New("mock error"), conn)
   161  	test.Assert(t, errMsg == "KITEX: processing error, remoteAddr=mock, error=mock error", errMsg)
   162  }
   163  
   164  // TestOnInactive covers onInactive() codes to check panic
   165  func TestOnInactive(t *testing.T) {
   166  	transHdler, err := NewSvrTransHandlerFactory(netpoll.NewSvrTransHandlerFactory(), nphttp2.NewSvrTransHandlerFactory()).NewTransHandler(&remote.ServerOption{
   167  		SvcSearchMap:  svcSearchMap,
   168  		TargetSvcInfo: svcInfo,
   169  	})
   170  	test.Assert(t, err == nil)
   171  
   172  	conn := &mocks.Conn{
   173  		RemoteAddrFunc: func() (r net.Addr) {
   174  			return utils.NewNetAddr("mock", "mock")
   175  		},
   176  	}
   177  
   178  	// case1 test noopHandler onInactive()
   179  	transHdler.OnInactive(context.Background(), conn)
   180  
   181  	// mock a ctx and set handlerKey
   182  	subHandler := &mocks.MockSvrTransHandler{}
   183  	subHandlerCtx := context.WithValue(
   184  		context.Background(),
   185  		handlerKey{},
   186  		&handlerWrapper{
   187  			handler: subHandler,
   188  		},
   189  	)
   190  
   191  	ctx := context.WithValue(
   192  		context.Background(),
   193  		handlerKey{},
   194  		&handlerWrapper{
   195  			ctx: subHandlerCtx,
   196  		},
   197  	)
   198  	// case2 test subHandler onInactive()
   199  	transHdler.OnInactive(ctx, conn)
   200  }
   201  
   202  // mockCall mocks how detection transHdlr processing with incoming requests
   203  func mockCall(transHdlr remote.ServerTransHandler, conn net.Conn) (err error) {
   204  	ctx := context.Background()
   205  	// do onConnActive
   206  	ctxWithHandler, err := transHdlr.OnActive(ctx, conn)
   207  	// onActive failed, such as connections limitation
   208  	if err != nil {
   209  		transHdlr.OnError(ctx, err, conn)
   210  		transHdlr.OnInactive(ctx, conn)
   211  		return
   212  	}
   213  	// do onConnRead
   214  	err = transHdlr.OnRead(ctxWithHandler, conn)
   215  	if err != nil {
   216  		transHdlr.OnError(ctxWithHandler, err, conn)
   217  		transHdlr.OnInactive(ctxWithHandler, conn)
   218  		return
   219  	}
   220  	// do onConnInactive
   221  	transHdlr.OnInactive(ctxWithHandler, conn)
   222  	return
   223  }