github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/nphttp2/server_handler_test.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package nphttp2
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/cloudwego/kitex/internal/test"
    26  	"github.com/cloudwego/kitex/pkg/kerrors"
    27  	"github.com/cloudwego/kitex/pkg/remote"
    28  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
    29  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    30  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    31  	"github.com/cloudwego/kitex/pkg/streaming"
    32  	"github.com/cloudwego/kitex/transport"
    33  )
    34  
    35  func TestServerHandler(t *testing.T) {
    36  	// init
    37  	opt := newMockServerOption()
    38  	msg := newMockNewMessage()
    39  	msg.ProtocolInfoFunc = func() remote.ProtocolInfo {
    40  		return remote.NewProtocolInfo(transport.PurePayload, serviceinfo.Protobuf)
    41  	}
    42  	msg.RPCInfoFunc = func() rpcinfo.RPCInfo {
    43  		return newMockRPCInfo()
    44  	}
    45  	npConn := newMockNpConn(mockAddr0)
    46  	npConn.mockSettingFrame()
    47  	tr, err := newMockServerTransport(npConn)
    48  	test.Assert(t, err == nil, err)
    49  	s := grpc.CreateStream(1, func(i int) {})
    50  	serverConn := newServerConn(tr, s)
    51  	defer serverConn.Close()
    52  
    53  	// test NewTransHandler()
    54  	handler, err := NewSvrTransHandlerFactory().NewTransHandler(opt)
    55  	test.Assert(t, err == nil, err)
    56  
    57  	// test Read()
    58  	// mock grpc encoded msg data into stream recv buffer
    59  	newMockStreamRecvHelloRequest(s)
    60  	ctx, err := handler.Read(context.Background(), serverConn, msg)
    61  	test.Assert(t, ctx != nil, ctx)
    62  	test.Assert(t, err == nil, err)
    63  
    64  	// test write()
    65  	ctx, err = handler.Write(context.Background(), serverConn, msg)
    66  	test.Assert(t, ctx != nil, ctx)
    67  	test.Assert(t, err == nil, err)
    68  
    69  	// test SetInvokeHandleFunc()
    70  	svrHdl := handler.(*svrTransHandler)
    71  	svrHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) {
    72  		return nil
    73  	})
    74  
    75  	// mock a setting frame to pass the server side preface check
    76  	npConn.mockSettingFrame()
    77  	// mock a headerFrame so onRead() can start working
    78  	npConn.mockMetaHeaderFrame()
    79  	go func() {
    80  		// test OnActive()
    81  		ctx, err := handler.OnActive(newMockCtxWithRPCInfo(), npConn)
    82  		test.Assert(t, err == nil, err)
    83  
    84  		handler.OnRead(ctx, npConn)
    85  
    86  		// test OnInactive()
    87  		handler.OnInactive(ctx, npConn)
    88  		test.Assert(t, err == nil, err)
    89  	}()
    90  
    91  	// sleep 50 mills so server can handle metaHeader frame
    92  	time.Sleep(time.Millisecond * 50)
    93  
    94  	// test OnError()
    95  	handler.OnError(context.Background(), context.Canceled, npConn)
    96  
    97  	// test SetPipeline()
    98  	handler.SetPipeline(nil)
    99  }
   100  
   101  type mockStream struct {
   102  	streaming.Stream
   103  	recv func(msg interface{}) error
   104  	send func(msg interface{}) error
   105  }
   106  
   107  func (s *mockStream) RecvMsg(m interface{}) error {
   108  	return s.recv(m)
   109  }
   110  
   111  func (s *mockStream) SendMsg(m interface{}) error {
   112  	return s.send(m)
   113  }
   114  
   115  func Test_invokeStreamUnaryHandler(t *testing.T) {
   116  	t.Run("recv err", func(t *testing.T) {
   117  		expectedErr := errors.New("mock err")
   118  		s := &mockStream{
   119  			recv: func(msg interface{}) error {
   120  				return expectedErr
   121  			},
   122  		}
   123  		var newArgsCalled, newResultCalled bool
   124  		mi := serviceinfo.NewMethodInfo(nil,
   125  			func() interface{} {
   126  				newArgsCalled = true
   127  				return nil
   128  			},
   129  			func() interface{} {
   130  				newResultCalled = true
   131  				return nil
   132  			},
   133  			false,
   134  		)
   135  		hdl := func(ctx context.Context, req, resp interface{}) (err error) {
   136  			return nil
   137  		}
   138  
   139  		err := invokeStreamUnaryHandler(context.Background(), s, mi, hdl, nil)
   140  
   141  		test.Assert(t, err == expectedErr, err)
   142  		test.Assert(t, newArgsCalled)
   143  		test.Assert(t, newResultCalled)
   144  	})
   145  	t.Run("handler err", func(t *testing.T) {
   146  		expectedErr := errors.New("mock err")
   147  		var newArgsCalled, newResultCalled, handlerCalled, recvCalled bool
   148  		s := &mockStream{
   149  			recv: func(msg interface{}) error {
   150  				recvCalled = true
   151  				return nil
   152  			},
   153  		}
   154  		mi := serviceinfo.NewMethodInfo(nil,
   155  			func() interface{} {
   156  				newArgsCalled = true
   157  				return nil
   158  			},
   159  			func() interface{} {
   160  				newResultCalled = true
   161  				return nil
   162  			},
   163  			false,
   164  		)
   165  		hdl := func(ctx context.Context, req, resp interface{}) (err error) {
   166  			handlerCalled = true
   167  			return expectedErr
   168  		}
   169  
   170  		err := invokeStreamUnaryHandler(context.Background(), s, mi, hdl, nil)
   171  
   172  		test.Assert(t, err == expectedErr, err)
   173  		test.Assert(t, recvCalled)
   174  		test.Assert(t, newArgsCalled)
   175  		test.Assert(t, newResultCalled)
   176  		test.Assert(t, handlerCalled)
   177  	})
   178  
   179  	t.Run("biz err", func(t *testing.T) {
   180  		expectedErr := kerrors.NewBizStatusError(100, "mock biz error")
   181  		var newArgsCalled, newResultCalled, handlerCalled, recvCalled, sendCalled bool
   182  		s := &mockStream{
   183  			recv: func(msg interface{}) error {
   184  				recvCalled = true
   185  				return nil
   186  			},
   187  			send: func(msg interface{}) error {
   188  				sendCalled = true
   189  				return nil
   190  			},
   191  		}
   192  		mi := serviceinfo.NewMethodInfo(nil,
   193  			func() interface{} {
   194  				newArgsCalled = true
   195  				return nil
   196  			},
   197  			func() interface{} {
   198  				newResultCalled = true
   199  				return nil
   200  			},
   201  			false,
   202  		)
   203  		hdl := func(ctx context.Context, req, resp interface{}) (err error) {
   204  			handlerCalled = true
   205  			return nil
   206  		}
   207  
   208  		ivk := rpcinfo.NewInvocation("test", "test")
   209  		ivk.SetBizStatusErr(expectedErr)
   210  		ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil)
   211  
   212  		err := invokeStreamUnaryHandler(context.Background(), s, mi, hdl, ri)
   213  
   214  		test.Assert(t, err == nil, err)
   215  		test.Assert(t, recvCalled)
   216  		test.Assert(t, newArgsCalled)
   217  		test.Assert(t, newResultCalled)
   218  		test.Assert(t, handlerCalled)
   219  		test.Assert(t, !sendCalled)
   220  	})
   221  
   222  	t.Run("send err", func(t *testing.T) {
   223  		expectedErr := errors.New("mock err")
   224  		var newArgsCalled, newResultCalled, handlerCalled, recvCalled, sendCalled bool
   225  		s := &mockStream{
   226  			recv: func(msg interface{}) error {
   227  				recvCalled = true
   228  				return nil
   229  			},
   230  			send: func(msg interface{}) error {
   231  				sendCalled = true
   232  				return expectedErr
   233  			},
   234  		}
   235  		mi := serviceinfo.NewMethodInfo(nil,
   236  			func() interface{} {
   237  				newArgsCalled = true
   238  				return nil
   239  			},
   240  			func() interface{} {
   241  				newResultCalled = true
   242  				return nil
   243  			},
   244  			false,
   245  		)
   246  		hdl := func(ctx context.Context, req, resp interface{}) (err error) {
   247  			handlerCalled = true
   248  			return nil
   249  		}
   250  
   251  		ivk := rpcinfo.NewInvocation("test", "test")
   252  		ri := rpcinfo.NewRPCInfo(nil, nil, ivk, nil, nil)
   253  
   254  		err := invokeStreamUnaryHandler(context.Background(), s, mi, hdl, ri)
   255  
   256  		test.Assert(t, err == expectedErr, err)
   257  		test.Assert(t, recvCalled)
   258  		test.Assert(t, newArgsCalled)
   259  		test.Assert(t, newResultCalled)
   260  		test.Assert(t, handlerCalled)
   261  		test.Assert(t, sendCalled)
   262  	})
   263  }