github.com/cloudwego/kitex@v0.9.0/server/server_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package server
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"net"
    23  	"os"
    24  	"reflect"
    25  	"strings"
    26  	"sync"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/bytedance/gopkg/cloud/metainfo"
    32  	"github.com/cloudwego/localsession"
    33  	"github.com/cloudwego/localsession/backup"
    34  	"github.com/golang/mock/gomock"
    35  
    36  	"github.com/cloudwego/kitex/internal/mocks"
    37  	mockslimiter "github.com/cloudwego/kitex/internal/mocks/limiter"
    38  	internal_server "github.com/cloudwego/kitex/internal/server"
    39  	"github.com/cloudwego/kitex/internal/test"
    40  	"github.com/cloudwego/kitex/pkg/endpoint"
    41  	"github.com/cloudwego/kitex/pkg/limit"
    42  	"github.com/cloudwego/kitex/pkg/limiter"
    43  	"github.com/cloudwego/kitex/pkg/registry"
    44  	"github.com/cloudwego/kitex/pkg/remote"
    45  	"github.com/cloudwego/kitex/pkg/remote/bound"
    46  	"github.com/cloudwego/kitex/pkg/remote/trans"
    47  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
    48  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    49  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    50  	"github.com/cloudwego/kitex/pkg/stats"
    51  	"github.com/cloudwego/kitex/pkg/transmeta"
    52  	"github.com/cloudwego/kitex/pkg/utils"
    53  	"github.com/cloudwego/kitex/transport"
    54  )
    55  
    56  var (
    57  	svcInfo      = mocks.ServiceInfo()
    58  	svcSearchMap = map[string]*serviceinfo.ServiceInfo{
    59  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockMethod):          svcInfo,
    60  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockExceptionMethod): svcInfo,
    61  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockErrorMethod):     svcInfo,
    62  		remote.BuildMultiServiceKey(mocks.MockServiceName, mocks.MockOnewayMethod):    svcInfo,
    63  		mocks.MockMethod:          svcInfo,
    64  		mocks.MockExceptionMethod: svcInfo,
    65  		mocks.MockErrorMethod:     svcInfo,
    66  		mocks.MockOnewayMethod:    svcInfo,
    67  	}
    68  )
    69  
    70  func TestServerRun(t *testing.T) {
    71  	var opts []Option
    72  	opts = append(opts, WithMetaHandler(noopMetahandler{}))
    73  	svr := NewServer(opts...)
    74  
    75  	time.AfterFunc(time.Millisecond*500, func() {
    76  		err := svr.Stop()
    77  		test.Assert(t, err == nil, err)
    78  	})
    79  
    80  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
    81  	test.Assert(t, err == nil)
    82  
    83  	var runHook int32
    84  	var shutdownHook int32
    85  	RegisterStartHook(func() {
    86  		atomic.AddInt32(&runHook, 1)
    87  	})
    88  	RegisterShutdownHook(func() {
    89  		atomic.AddInt32(&shutdownHook, 1)
    90  	})
    91  	err = svr.Run()
    92  	test.Assert(t, err == nil, err)
    93  
    94  	test.Assert(t, atomic.LoadInt32(&runHook) == 1)
    95  	test.Assert(t, atomic.LoadInt32(&shutdownHook) == 1)
    96  }
    97  
    98  func TestReusePortServerRun(t *testing.T) {
    99  	hostPort := test.GetLocalAddress()
   100  	addr, _ := net.ResolveTCPAddr("tcp", hostPort)
   101  	var opts []Option
   102  	opts = append(opts, WithReusePort(true))
   103  	opts = append(opts, WithServiceAddr(addr), WithExitWaitTime(time.Microsecond*10))
   104  
   105  	var wg sync.WaitGroup
   106  	for i := 0; i < 8; i++ {
   107  		wg.Add(2)
   108  		go func() {
   109  			defer wg.Done()
   110  			svr := NewServer(opts...)
   111  			time.AfterFunc(time.Millisecond*100, func() {
   112  				defer wg.Done()
   113  				err := svr.Stop()
   114  				test.Assert(t, err == nil, err)
   115  			})
   116  			err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   117  			test.Assert(t, err == nil)
   118  			err = svr.Run()
   119  			test.Assert(t, err == nil, err)
   120  		}()
   121  	}
   122  	wg.Wait()
   123  }
   124  
   125  func TestInitOrResetRPCInfo(t *testing.T) {
   126  	var opts []Option
   127  	rwTimeout := time.Millisecond
   128  	opts = append(opts, WithReadWriteTimeout(rwTimeout))
   129  	svr := &server{
   130  		opt:  internal_server.NewOptions(opts),
   131  		svcs: newServices(),
   132  	}
   133  	svr.init()
   134  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   135  	test.Assert(t, err == nil)
   136  
   137  	remoteAddr := utils.NewNetAddr("tcp", "to")
   138  	conn := new(mocks.Conn)
   139  	conn.RemoteAddrFunc = func() (r net.Addr) {
   140  		return remoteAddr
   141  	}
   142  	rpcInfoInitFunc := svr.initOrResetRPCInfoFunc()
   143  	ri := rpcInfoInitFunc(nil, conn.RemoteAddr())
   144  	test.Assert(t, ri != nil)
   145  	test.Assert(t, ri.From().Address().String() == remoteAddr.String())
   146  	test.Assert(t, ri.Config().ReadWriteTimeout() == rwTimeout)
   147  
   148  	// modify rpcinfo
   149  	fi := rpcinfo.AsMutableEndpointInfo(ri.From())
   150  	fi.SetServiceName("mock service")
   151  	fi.SetMethod("mock method")
   152  	fi.SetTag("key", "value")
   153  
   154  	ti := rpcinfo.AsMutableEndpointInfo(ri.To())
   155  	ti.SetServiceName("mock service")
   156  	ti.SetMethod("mock method")
   157  	ti.SetTag("key", "value")
   158  
   159  	if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok {
   160  		setter.SetSeqID(123)
   161  		setter.SetPackageName("mock package")
   162  		setter.SetServiceName("mock service")
   163  		setter.SetMethodName("mock method")
   164  	}
   165  
   166  	mc := rpcinfo.AsMutableRPCConfig(ri.Config())
   167  	mc.SetTransportProtocol(transport.TTHeader)
   168  	mc.SetConnectTimeout(10 * time.Second)
   169  	mc.SetRPCTimeout(20 * time.Second)
   170  	mc.SetInteractionMode(rpcinfo.Streaming)
   171  	mc.SetIOBufferSize(1024)
   172  	mc.SetReadWriteTimeout(30 * time.Second)
   173  
   174  	rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats())
   175  	rpcStats.SetRecvSize(1024)
   176  	rpcStats.SetSendSize(1024)
   177  	rpcStats.SetPanicked(errors.New("panic"))
   178  	rpcStats.SetError(errors.New("err"))
   179  	rpcStats.SetLevel(stats.LevelDetailed)
   180  
   181  	// check setting
   182  	test.Assert(t, ri.From().ServiceName() == "mock service")
   183  	test.Assert(t, ri.From().Method() == "mock method")
   184  	value, exist := ri.From().Tag("key")
   185  	test.Assert(t, exist && value == "value")
   186  
   187  	test.Assert(t, ri.To().ServiceName() == "mock service")
   188  	test.Assert(t, ri.To().Method() == "mock method")
   189  	value, exist = ri.To().Tag("key")
   190  	test.Assert(t, exist && value == "value")
   191  
   192  	test.Assert(t, ri.Invocation().SeqID() == 123)
   193  	test.Assert(t, ri.Invocation().PackageName() == "mock package")
   194  	test.Assert(t, ri.Invocation().ServiceName() == "mock service")
   195  	test.Assert(t, ri.Invocation().MethodName() == "mock method")
   196  
   197  	test.Assert(t, ri.Config().TransportProtocol() == transport.TTHeader)
   198  	test.Assert(t, ri.Config().ConnectTimeout() == 10*time.Second)
   199  	test.Assert(t, ri.Config().RPCTimeout() == 20*time.Second)
   200  	test.Assert(t, ri.Config().InteractionMode() == rpcinfo.Streaming)
   201  	test.Assert(t, ri.Config().IOBufferSize() == 1024)
   202  	test.Assert(t, ri.Config().ReadWriteTimeout() == rwTimeout)
   203  
   204  	test.Assert(t, ri.Stats().RecvSize() == 1024)
   205  	test.Assert(t, ri.Stats().SendSize() == 1024)
   206  	hasPanicked, _ := ri.Stats().Panicked()
   207  	test.Assert(t, hasPanicked)
   208  	test.Assert(t, ri.Stats().Error() != nil)
   209  	test.Assert(t, ri.Stats().Level() == stats.LevelDetailed)
   210  
   211  	// test reset
   212  	pOld := reflect.ValueOf(ri).Pointer()
   213  	ri = rpcInfoInitFunc(ri, conn.RemoteAddr())
   214  	pNew := reflect.ValueOf(ri).Pointer()
   215  	test.Assert(t, pOld == pNew, pOld, pNew)
   216  	test.Assert(t, ri.From().Address().String() == remoteAddr.String())
   217  	test.Assert(t, ri.Config().ReadWriteTimeout() == rwTimeout)
   218  
   219  	test.Assert(t, ri.From().ServiceName() == "")
   220  	test.Assert(t, ri.From().Method() == "")
   221  	_, exist = ri.From().Tag("key")
   222  	test.Assert(t, !exist)
   223  
   224  	test.Assert(t, ri.To().ServiceName() == "")
   225  	test.Assert(t, ri.To().Method() == "")
   226  	_, exist = ri.To().Tag("key")
   227  	test.Assert(t, !exist)
   228  
   229  	test.Assert(t, ri.Invocation().SeqID() == 0)
   230  	test.Assert(t, ri.Invocation().PackageName() == "")
   231  	test.Assert(t, ri.Invocation().ServiceName() == "")
   232  	test.Assert(t, ri.Invocation().MethodName() == "")
   233  
   234  	test.Assert(t, ri.Config().TransportProtocol() == 0)
   235  	test.Assert(t, ri.Config().ConnectTimeout() == 50*time.Millisecond)
   236  	test.Assert(t, ri.Config().RPCTimeout() == 0)
   237  	test.Assert(t, ri.Config().InteractionMode() == rpcinfo.PingPong)
   238  	test.Assert(t, ri.Config().IOBufferSize() == 4096)
   239  	test.Assert(t, ri.Config().ReadWriteTimeout() == rwTimeout)
   240  
   241  	test.Assert(t, ri.Stats().RecvSize() == 0)
   242  	test.Assert(t, ri.Stats().SendSize() == 0)
   243  	_, panicked := ri.Stats().Panicked()
   244  	test.Assert(t, panicked == nil)
   245  	test.Assert(t, ri.Stats().Error() == nil)
   246  	test.Assert(t, ri.Stats().Level() == 0)
   247  
   248  	// test reset after rpcInfoPool is disabled
   249  	t.Run("reset with pool disabled", func(t *testing.T) {
   250  		backupState := rpcinfo.PoolEnabled()
   251  		defer rpcinfo.EnablePool(backupState)
   252  		rpcinfo.EnablePool(false)
   253  
   254  		riNew := rpcInfoInitFunc(ri, conn.RemoteAddr())
   255  		pOld, pNew := reflect.ValueOf(ri).Pointer(), reflect.ValueOf(riNew).Pointer()
   256  		test.Assert(t, pOld != pNew, pOld, pNew)
   257  	})
   258  }
   259  
   260  func TestServiceRegisterFailed(t *testing.T) {
   261  	mockRegErr := errors.New("mock register error")
   262  	var rCount int
   263  	var drCount int
   264  	mockRegistry := MockRegistry{
   265  		RegisterFunc: func(info *registry.Info) error {
   266  			rCount++
   267  			return mockRegErr
   268  		},
   269  		DeregisterFunc: func(info *registry.Info) error {
   270  			drCount++
   271  			return nil
   272  		},
   273  	}
   274  	var opts []Option
   275  	opts = append(opts, WithRegistry(mockRegistry))
   276  	svr := NewServer(opts...)
   277  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   278  	test.Assert(t, err == nil)
   279  
   280  	err = svr.Run()
   281  	test.Assert(t, err != nil)
   282  	test.Assert(t, strings.Contains(err.Error(), mockRegErr.Error()))
   283  	test.Assert(t, drCount == 1)
   284  }
   285  
   286  func TestServiceDeregisterFailed(t *testing.T) {
   287  	mockDeregErr := errors.New("mock deregister error")
   288  	var rCount int
   289  	var drCount int
   290  	mockRegistry := MockRegistry{
   291  		RegisterFunc: func(info *registry.Info) error {
   292  			rCount++
   293  			return nil
   294  		},
   295  		DeregisterFunc: func(info *registry.Info) error {
   296  			drCount++
   297  			return mockDeregErr
   298  		},
   299  	}
   300  	var opts []Option
   301  	opts = append(opts, WithRegistry(mockRegistry))
   302  	svr := NewServer(opts...)
   303  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   304  	test.Assert(t, err == nil)
   305  
   306  	time.AfterFunc(2000*time.Millisecond, func() {
   307  		err := svr.Stop()
   308  		test.Assert(t, strings.Contains(err.Error(), mockDeregErr.Error()))
   309  	})
   310  	err = svr.Run()
   311  	test.Assert(t, err == nil, err)
   312  	test.Assert(t, rCount == 1)
   313  }
   314  
   315  func TestServiceRegistryInfo(t *testing.T) {
   316  	registryInfo := &registry.Info{
   317  		Weight: 100,
   318  		Tags:   map[string]string{"aa": "bb"},
   319  	}
   320  	checkInfo := func(info *registry.Info) {
   321  		test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec)
   322  		test.Assert(t, info.Weight == registryInfo.Weight, info.Addr)
   323  		test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr)
   324  		test.Assert(t, len(info.Tags) == len(registryInfo.Tags), info.Tags)
   325  		test.Assert(t, info.Tags["aa"] == registryInfo.Tags["aa"], info.Tags)
   326  	}
   327  	var rCount int
   328  	var drCount int
   329  	mockRegistry := MockRegistry{
   330  		RegisterFunc: func(info *registry.Info) error {
   331  			checkInfo(info)
   332  			rCount++
   333  			return nil
   334  		},
   335  		DeregisterFunc: func(info *registry.Info) error {
   336  			checkInfo(info)
   337  			drCount++
   338  			return nil
   339  		},
   340  	}
   341  	var opts []Option
   342  	opts = append(opts, WithRegistry(mockRegistry))
   343  	opts = append(opts, WithRegistryInfo(registryInfo))
   344  	svr := NewServer(opts...)
   345  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   346  	test.Assert(t, err == nil)
   347  
   348  	time.AfterFunc(2000*time.Millisecond, func() {
   349  		err := svr.Stop()
   350  		test.Assert(t, err == nil, err)
   351  	})
   352  	err = svr.Run()
   353  	test.Assert(t, err == nil, err)
   354  	test.Assert(t, rCount == 1)
   355  	test.Assert(t, drCount == 1)
   356  }
   357  
   358  func TestServiceRegistryNoInitInfo(t *testing.T) {
   359  	checkInfo := func(info *registry.Info) {
   360  		test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec)
   361  		test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr)
   362  	}
   363  	var rCount int
   364  	var drCount int
   365  	mockRegistry := MockRegistry{
   366  		RegisterFunc: func(info *registry.Info) error {
   367  			checkInfo(info)
   368  			rCount++
   369  			return nil
   370  		},
   371  		DeregisterFunc: func(info *registry.Info) error {
   372  			checkInfo(info)
   373  			drCount++
   374  			return nil
   375  		},
   376  	}
   377  	var opts []Option
   378  	opts = append(opts, WithRegistry(mockRegistry))
   379  	svr := NewServer(opts...)
   380  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   381  	test.Assert(t, err == nil)
   382  
   383  	time.AfterFunc(2000*time.Millisecond, func() {
   384  		err := svr.Stop()
   385  		test.Assert(t, err == nil, err)
   386  	})
   387  	err = svr.Run()
   388  	test.Assert(t, err == nil, err)
   389  	test.Assert(t, rCount == 1)
   390  	test.Assert(t, drCount == 1)
   391  }
   392  
   393  // TestServiceRegistryInfoWithNilTags is to check the Tags val. If Tags of RegistryInfo is nil,
   394  // the Tags of ServerBasicInfo will be assigned to Tags of RegistryInfo
   395  func TestServiceRegistryInfoWithNilTags(t *testing.T) {
   396  	registryInfo := &registry.Info{
   397  		Weight: 100,
   398  	}
   399  	checkInfo := func(info *registry.Info) {
   400  		test.Assert(t, info.PayloadCodec == serviceinfo.Thrift.String(), info.PayloadCodec)
   401  		test.Assert(t, info.Addr.String() == "[::]:8888", info.Addr)
   402  		test.Assert(t, info.Weight == registryInfo.Weight, info.Weight)
   403  		test.Assert(t, info.Tags["aa"] == "bb", info.Tags)
   404  	}
   405  	var rCount int
   406  	var drCount int
   407  	mockRegistry := MockRegistry{
   408  		RegisterFunc: func(info *registry.Info) error {
   409  			checkInfo(info)
   410  			rCount++
   411  			return nil
   412  		},
   413  		DeregisterFunc: func(info *registry.Info) error {
   414  			checkInfo(info)
   415  			drCount++
   416  			return nil
   417  		},
   418  	}
   419  	var opts []Option
   420  	opts = append(opts, WithRegistry(mockRegistry))
   421  	opts = append(opts, WithRegistryInfo(registryInfo))
   422  	opts = append(opts, WithServerBasicInfo(&rpcinfo.EndpointBasicInfo{
   423  		Tags: map[string]string{"aa": "bb"},
   424  	}))
   425  	svr := NewServer(opts...)
   426  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   427  	test.Assert(t, err == nil)
   428  
   429  	time.AfterFunc(2000*time.Millisecond, func() {
   430  		err := svr.Stop()
   431  		test.Assert(t, err == nil, err)
   432  	})
   433  	err = svr.Run()
   434  	test.Assert(t, err == nil, err)
   435  	test.Assert(t, rCount == 1)
   436  	test.Assert(t, drCount == 1)
   437  }
   438  
   439  func TestGRPCServerMultipleServices(t *testing.T) {
   440  	var opts []Option
   441  	opts = append(opts, withGRPCTransport())
   442  	svr := NewServer(opts...)
   443  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   444  	test.Assert(t, err == nil)
   445  	err = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler())
   446  	test.Assert(t, err == nil)
   447  	test.DeepEqual(t, svr.GetServiceInfos()[mocks.MockMethod], mocks.ServiceInfo())
   448  	test.DeepEqual(t, svr.GetServiceInfos()[mocks.Mock2Method], mocks.Service2Info())
   449  	time.AfterFunc(1000*time.Millisecond, func() {
   450  		err := svr.Stop()
   451  		test.Assert(t, err == nil, err)
   452  	})
   453  	err = svr.Run()
   454  	test.Assert(t, err == nil, err)
   455  }
   456  
   457  func TestServerBoundHandler(t *testing.T) {
   458  	ctrl := gomock.NewController(t)
   459  	defer ctrl.Finish()
   460  
   461  	interval := time.Millisecond * 100
   462  	cases := []struct {
   463  		opts          []Option
   464  		wantInbounds  []remote.InboundHandler
   465  		wantOutbounds []remote.OutboundHandler
   466  	}{
   467  		{
   468  			opts: []Option{
   469  				WithLimit(&limit.Option{
   470  					MaxConnections: 1000,
   471  					MaxQPS:         10000,
   472  				}),
   473  				WithMetaHandler(noopMetahandler{}),
   474  			},
   475  			wantInbounds: []remote.InboundHandler{
   476  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}),
   477  				bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, false),
   478  			},
   479  			wantOutbounds: []remote.OutboundHandler{
   480  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}),
   481  			},
   482  		},
   483  		{
   484  			opts: []Option{
   485  				WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)),
   486  			},
   487  			wantInbounds: []remote.InboundHandler{
   488  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   489  				bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), &limiter.DummyRateLimiter{}, nil, false),
   490  			},
   491  			wantOutbounds: []remote.OutboundHandler{
   492  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   493  			},
   494  		},
   495  		{
   496  			opts: []Option{
   497  				WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)),
   498  			},
   499  			wantInbounds: []remote.InboundHandler{
   500  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   501  				bound.NewServerLimiterHandler(&limiter.DummyConcurrencyLimiter{}, mockslimiter.NewMockRateLimiter(ctrl), nil, true),
   502  			},
   503  			wantOutbounds: []remote.OutboundHandler{
   504  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   505  			},
   506  		},
   507  		{
   508  			opts: []Option{
   509  				WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)),
   510  				WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)),
   511  			},
   512  			wantInbounds: []remote.InboundHandler{
   513  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   514  				bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true),
   515  			},
   516  			wantOutbounds: []remote.OutboundHandler{
   517  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   518  			},
   519  		},
   520  		{
   521  			opts: []Option{
   522  				WithLimit(&limit.Option{
   523  					MaxConnections: 1000,
   524  					MaxQPS:         10000,
   525  				}),
   526  				WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)),
   527  			},
   528  			wantInbounds: []remote.InboundHandler{
   529  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   530  				bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), limiter.NewQPSLimiter(interval, 10000), nil, false),
   531  			},
   532  			wantOutbounds: []remote.OutboundHandler{
   533  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   534  			},
   535  		},
   536  		{
   537  			opts: []Option{
   538  				WithLimit(&limit.Option{
   539  					MaxConnections: 1000,
   540  					MaxQPS:         10000,
   541  				}),
   542  				WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)),
   543  				WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)),
   544  			},
   545  			wantInbounds: []remote.InboundHandler{
   546  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   547  				bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true),
   548  			},
   549  			wantOutbounds: []remote.OutboundHandler{
   550  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   551  			},
   552  		},
   553  		{
   554  			opts: []Option{
   555  				WithLimit(&limit.Option{
   556  					MaxConnections: 1000,
   557  					MaxQPS:         10000,
   558  				}),
   559  				WithConnectionLimiter(mockslimiter.NewMockConcurrencyLimiter(ctrl)),
   560  				WithQPSLimiter(mockslimiter.NewMockRateLimiter(ctrl)),
   561  				WithMuxTransport(),
   562  			},
   563  			wantInbounds: []remote.InboundHandler{
   564  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   565  				bound.NewServerLimiterHandler(mockslimiter.NewMockConcurrencyLimiter(ctrl), mockslimiter.NewMockRateLimiter(ctrl), nil, true),
   566  			},
   567  			wantOutbounds: []remote.OutboundHandler{
   568  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   569  			},
   570  		},
   571  		{
   572  			opts: []Option{
   573  				WithLimit(&limit.Option{
   574  					MaxConnections: 1000,
   575  					MaxQPS:         10000,
   576  				}),
   577  				WithMuxTransport(),
   578  			},
   579  			wantInbounds: []remote.InboundHandler{
   580  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   581  				bound.NewServerLimiterHandler(limiter.NewConnectionLimiter(1000), limiter.NewQPSLimiter(interval, 10000), nil, true),
   582  			},
   583  			wantOutbounds: []remote.OutboundHandler{
   584  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler}),
   585  			},
   586  		},
   587  		{
   588  			opts: []Option{
   589  				WithMetaHandler(noopMetahandler{}),
   590  			},
   591  			wantInbounds: []remote.InboundHandler{
   592  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}),
   593  			},
   594  			wantOutbounds: []remote.OutboundHandler{
   595  				bound.NewTransMetaHandler([]remote.MetaHandler{transmeta.MetainfoServerHandler, noopMetahandler{}}),
   596  			},
   597  		},
   598  	}
   599  	for _, tcase := range cases {
   600  		opts := append(tcase.opts, WithExitWaitTime(time.Millisecond*10))
   601  		svr := NewServer(opts...)
   602  
   603  		time.AfterFunc(100*time.Millisecond, func() {
   604  			err := svr.Stop()
   605  			test.Assert(t, err == nil, err)
   606  		})
   607  		err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   608  		test.Assert(t, err == nil)
   609  		err = svr.Run()
   610  		test.Assert(t, err == nil, err)
   611  
   612  		iSvr := svr.(*server)
   613  		test.Assert(t, inboundDeepEqual(iSvr.opt.RemoteOpt.Inbounds, tcase.wantInbounds))
   614  		test.Assert(t, reflect.DeepEqual(iSvr.opt.RemoteOpt.Outbounds, tcase.wantOutbounds))
   615  		svr.Stop()
   616  	}
   617  }
   618  
   619  func TestInvokeHandlerWithContextBackup(t *testing.T) {
   620  	testInvokeHandlerWithSession(t, true, ":8888")
   621  	os.Setenv(localsession.SESSION_CONFIG_KEY, "true,100,1h")
   622  	testInvokeHandlerWithSession(t, false, ":8889")
   623  }
   624  
   625  func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) {
   626  	callMethod := "mock"
   627  	var opts []Option
   628  	mwExec := false
   629  	opts = append(opts, WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint {
   630  		return func(ctx context.Context, req, resp interface{}) (err error) {
   631  			mwExec = true
   632  			return next(ctx, req, resp)
   633  		}
   634  	}))
   635  
   636  	k1, v1 := "1", "1"
   637  	var backupHandler backup.BackupHandler
   638  	if !fail {
   639  		opts = append(opts, WithContextBackup(true, true))
   640  		backupHandler = func(prev, cur context.Context) (context.Context, bool) {
   641  			v := prev.Value(k1)
   642  			if v != nil {
   643  				cur = context.WithValue(cur, k1, v)
   644  				return cur, true
   645  			}
   646  			return cur, true
   647  		}
   648  	}
   649  
   650  	opts = append(opts, WithCodec(&mockCodec{}))
   651  	transHdlrFact := &mockSvrTransHandlerFactory{}
   652  	exitCh := make(chan bool)
   653  	var ln net.Listener
   654  	transSvr := &mocks.MockTransServer{
   655  		BootstrapServerFunc: func(net.Listener) error {
   656  			{ // mock server call
   657  				ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats())
   658  				ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   659  				recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
   660  				recvMsg.NewData(callMethod)
   661  				sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server)
   662  
   663  				// inject kvs here
   664  				ctx = metainfo.WithPersistentValue(ctx, "a", "b")
   665  				ctx = context.WithValue(ctx, k1, v1)
   666  
   667  				_, err := transHdlrFact.hdlr.OnMessage(ctx, recvMsg, sendMsg)
   668  				test.Assert(t, err == nil, err)
   669  			}
   670  			<-exitCh
   671  			return nil
   672  		},
   673  		ShutdownFunc: func() error {
   674  			if ln != nil {
   675  				ln.Close()
   676  			}
   677  			exitCh <- true
   678  			return nil
   679  		},
   680  		CreateListenerFunc: func(addr net.Addr) (net.Listener, error) {
   681  			var err error
   682  			ln, err = net.Listen("tcp", ad)
   683  			return ln, err
   684  		},
   685  	}
   686  	opts = append(opts, WithTransServerFactory(mocks.NewMockTransServerFactory(transSvr)))
   687  	opts = append(opts, WithTransHandlerFactory(transHdlrFact))
   688  
   689  	svr := NewServer(opts...)
   690  	time.AfterFunc(100*time.Millisecond, func() {
   691  		err := svr.Stop()
   692  		test.Assert(t, err == nil, err)
   693  	})
   694  	serviceHandler := false
   695  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MockFuncHandler(func(ctx context.Context, req *mocks.MyRequest) (r *mocks.MyResponse, err error) {
   696  		serviceHandler = true
   697  
   698  		wg := sync.WaitGroup{}
   699  		wg.Add(1)
   700  		go func() {
   701  			defer wg.Done()
   702  
   703  			// miss context here
   704  			ctx := backup.RecoverCtxOnDemands(context.Background(), backupHandler)
   705  
   706  			if !fail {
   707  				b, _ := metainfo.GetPersistentValue(ctx, "a")
   708  				test.Assert(t, b == "b", "can't get metainfo")
   709  				test.Assert(t, ctx.Value(k1) == v1, "can't get v1")
   710  			} else {
   711  				_, ok := metainfo.GetPersistentValue(ctx, "a")
   712  				test.Assert(t, !ok, "can get metainfo")
   713  				test.Assert(t, ctx.Value(k1) != v1, "can get v1")
   714  			}
   715  		}()
   716  		wg.Wait()
   717  
   718  		return &mocks.MyResponse{Name: "mock"}, nil
   719  	}))
   720  	test.Assert(t, err == nil)
   721  	err = svr.Run()
   722  	test.Assert(t, err == nil, err)
   723  	test.Assert(t, mwExec)
   724  	test.Assert(t, serviceHandler)
   725  }
   726  
   727  func TestInvokeHandlerExec(t *testing.T) {
   728  	callMethod := "mock"
   729  	var opts []Option
   730  	mwExec := false
   731  	opts = append(opts, WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint {
   732  		return func(ctx context.Context, req, resp interface{}) (err error) {
   733  			mwExec = true
   734  			return next(ctx, req, resp)
   735  		}
   736  	}))
   737  	opts = append(opts, WithCodec(&mockCodec{}))
   738  	transHdlrFact := &mockSvrTransHandlerFactory{}
   739  	exitCh := make(chan bool)
   740  	var ln net.Listener
   741  	transSvr := &mocks.MockTransServer{
   742  		BootstrapServerFunc: func(net.Listener) error {
   743  			{ // mock server call
   744  				ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats())
   745  				ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   746  				recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
   747  				recvMsg.NewData(callMethod)
   748  				sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server)
   749  
   750  				_, err := transHdlrFact.hdlr.OnMessage(ctx, recvMsg, sendMsg)
   751  				test.Assert(t, err == nil, err)
   752  			}
   753  			<-exitCh
   754  			return nil
   755  		},
   756  		ShutdownFunc: func() error {
   757  			if ln != nil {
   758  				ln.Close()
   759  			}
   760  			exitCh <- true
   761  			return nil
   762  		},
   763  		CreateListenerFunc: func(addr net.Addr) (net.Listener, error) {
   764  			var err error
   765  			ln, err = net.Listen("tcp", ":8888")
   766  			return ln, err
   767  		},
   768  	}
   769  	opts = append(opts, WithTransServerFactory(mocks.NewMockTransServerFactory(transSvr)))
   770  	opts = append(opts, WithTransHandlerFactory(transHdlrFact))
   771  
   772  	svr := NewServer(opts...)
   773  	time.AfterFunc(100*time.Millisecond, func() {
   774  		err := svr.Stop()
   775  		test.Assert(t, err == nil, err)
   776  	})
   777  	serviceHandler := false
   778  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MockFuncHandler(func(ctx context.Context, req *mocks.MyRequest) (r *mocks.MyResponse, err error) {
   779  		serviceHandler = true
   780  		return &mocks.MyResponse{Name: "mock"}, nil
   781  	}))
   782  	test.Assert(t, err == nil)
   783  	err = svr.Run()
   784  	test.Assert(t, err == nil, err)
   785  	test.Assert(t, mwExec)
   786  	test.Assert(t, serviceHandler)
   787  }
   788  
   789  func TestInvokeHandlerPanic(t *testing.T) {
   790  	callMethod := "mock"
   791  	var opts []Option
   792  	mwExec := false
   793  	opts = append(opts, WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint {
   794  		return func(ctx context.Context, req, resp interface{}) (err error) {
   795  			mwExec = true
   796  			return next(ctx, req, resp)
   797  		}
   798  	}))
   799  	opts = append(opts, WithCodec(&mockCodec{}))
   800  	transHdlrFact := &mockSvrTransHandlerFactory{}
   801  	exitCh := make(chan bool)
   802  	var ln net.Listener
   803  	transSvr := &mocks.MockTransServer{
   804  		BootstrapServerFunc: func(net.Listener) error {
   805  			{
   806  				// mock server call
   807  				ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats())
   808  				ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
   809  				recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
   810  				recvMsg.NewData(callMethod)
   811  				sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server)
   812  
   813  				_, err := transHdlrFact.hdlr.OnMessage(ctx, recvMsg, sendMsg)
   814  				test.Assert(t, strings.Contains(err.Error(), "happened in biz handler"))
   815  			}
   816  			<-exitCh
   817  			return nil
   818  		},
   819  		ShutdownFunc: func() error {
   820  			ln.Close()
   821  			exitCh <- true
   822  			return nil
   823  		},
   824  		CreateListenerFunc: func(addr net.Addr) (net.Listener, error) {
   825  			var err error
   826  			ln, err = net.Listen("tcp", ":8888")
   827  			return ln, err
   828  		},
   829  	}
   830  	opts = append(opts, WithTransServerFactory(mocks.NewMockTransServerFactory(transSvr)))
   831  	opts = append(opts, WithTransHandlerFactory(transHdlrFact))
   832  
   833  	svr := NewServer(opts...)
   834  	time.AfterFunc(100*time.Millisecond, func() {
   835  		err := svr.Stop()
   836  		test.Assert(t, err == nil, err)
   837  	})
   838  	serviceHandler := false
   839  	err := svr.RegisterService(mocks.ServiceInfo(), mocks.MockFuncHandler(func(ctx context.Context, req *mocks.MyRequest) (r *mocks.MyResponse, err error) {
   840  		serviceHandler = true
   841  		panic("test")
   842  	}))
   843  	test.Assert(t, err == nil)
   844  	err = svr.Run()
   845  	test.Assert(t, err == nil, err)
   846  	test.Assert(t, mwExec)
   847  	test.Assert(t, serviceHandler)
   848  }
   849  
   850  func TestRegisterService(t *testing.T) {
   851  	svr := NewServer()
   852  	time.AfterFunc(time.Second, func() {
   853  		err := svr.Stop()
   854  		test.Assert(t, err == nil, err)
   855  	})
   856  
   857  	svr.Run()
   858  
   859  	test.PanicAt(t, func() {
   860  		_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   861  	}, func(err interface{}) bool {
   862  		if errMsg, ok := err.(string); ok {
   863  			return strings.Contains(errMsg, "server is running")
   864  		}
   865  		return true
   866  	})
   867  	svr.Stop()
   868  
   869  	svr = NewServer()
   870  	time.AfterFunc(time.Second, func() {
   871  		err := svr.Stop()
   872  		test.Assert(t, err == nil, err)
   873  	})
   874  
   875  	test.PanicAt(t, func() {
   876  		_ = svr.RegisterService(nil, mocks.MyServiceHandler())
   877  	}, func(err interface{}) bool {
   878  		if errMsg, ok := err.(string); ok {
   879  			return strings.Contains(errMsg, "svcInfo is nil")
   880  		}
   881  		return true
   882  	})
   883  
   884  	test.PanicAt(t, func() {
   885  		_ = svr.RegisterService(mocks.ServiceInfo(), nil)
   886  	}, func(err interface{}) bool {
   887  		if errMsg, ok := err.(string); ok {
   888  			return strings.Contains(errMsg, "handler is nil")
   889  		}
   890  		return true
   891  	})
   892  
   893  	test.PanicAt(t, func() {
   894  		_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler(), WithFallbackService())
   895  		_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   896  	}, func(err interface{}) bool {
   897  		if errMsg, ok := err.(string); ok {
   898  			return strings.Contains(errMsg, "Service[MockService] is already defined")
   899  		}
   900  		return true
   901  	})
   902  
   903  	test.PanicAt(t, func() {
   904  		_ = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler(), WithFallbackService())
   905  	}, func(err interface{}) bool {
   906  		if errMsg, ok := err.(string); ok {
   907  			return strings.Contains(errMsg, "multiple fallback services cannot be registered")
   908  		}
   909  		return true
   910  	})
   911  	svr.Stop()
   912  
   913  	svr = NewServer()
   914  	time.AfterFunc(time.Second, func() {
   915  		err := svr.Stop()
   916  		test.Assert(t, err == nil, err)
   917  	})
   918  
   919  	_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
   920  	_ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler())
   921  	err := svr.Run()
   922  	test.Assert(t, err != nil)
   923  	test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified")
   924  	svr.Stop()
   925  }
   926  
   927  type noopMetahandler struct{}
   928  
   929  func (noopMetahandler) WriteMeta(ctx context.Context, msg remote.Message) (context.Context, error) {
   930  	return ctx, nil
   931  }
   932  
   933  func (noopMetahandler) ReadMeta(ctx context.Context, msg remote.Message) (context.Context, error) {
   934  	return ctx, nil
   935  }
   936  func (noopMetahandler) OnConnectStream(ctx context.Context) (context.Context, error) { return ctx, nil }
   937  func (noopMetahandler) OnReadStream(ctx context.Context) (context.Context, error)    { return ctx, nil }
   938  
   939  type mockSvrTransHandlerFactory struct {
   940  	hdlr remote.ServerTransHandler
   941  }
   942  
   943  func (f *mockSvrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) {
   944  	f.hdlr, _ = trans.NewDefaultSvrTransHandler(opt, &mockExtension{})
   945  	return f.hdlr, nil
   946  }
   947  
   948  type mockExtension struct{}
   949  
   950  func (m mockExtension) SetReadTimeout(ctx context.Context, conn net.Conn, cfg rpcinfo.RPCConfig, role remote.RPCRole) {
   951  }
   952  
   953  func (m mockExtension) NewWriteByteBuffer(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
   954  	return remote.NewWriterBuffer(0)
   955  }
   956  
   957  func (m mockExtension) NewReadByteBuffer(ctx context.Context, conn net.Conn, msg remote.Message) remote.ByteBuffer {
   958  	return remote.NewReaderBuffer(nil)
   959  }
   960  
   961  func (m mockExtension) ReleaseBuffer(buffer remote.ByteBuffer, err error) error {
   962  	return nil
   963  }
   964  
   965  func (m mockExtension) IsTimeoutErr(err error) bool {
   966  	return false
   967  }
   968  
   969  func (m mockExtension) IsRemoteClosedErr(err error) bool {
   970  	return false
   971  }
   972  
   973  type mockCodec struct {
   974  	EncodeFunc func(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error
   975  	DecodeFunc func(ctx context.Context, msg remote.Message, in remote.ByteBuffer) error
   976  }
   977  
   978  func (m *mockCodec) Name() string {
   979  	return "Mock"
   980  }
   981  
   982  func (m *mockCodec) Encode(ctx context.Context, msg remote.Message, out remote.ByteBuffer) (err error) {
   983  	if m.EncodeFunc != nil {
   984  		return m.EncodeFunc(ctx, msg, out)
   985  	}
   986  	return
   987  }
   988  
   989  func (m *mockCodec) Decode(ctx context.Context, msg remote.Message, in remote.ByteBuffer) (err error) {
   990  	if m.DecodeFunc != nil {
   991  		return m.DecodeFunc(ctx, msg, in)
   992  	}
   993  	return
   994  }
   995  
   996  func TestDuplicatedRegisterInfoPanic(t *testing.T) {
   997  	svcs := newServices()
   998  	svcs.addService(mocks.ServiceInfo(), nil, &RegisterOptions{})
   999  	s := &server{
  1000  		opt:  internal_server.NewOptions(nil),
  1001  		svcs: svcs,
  1002  	}
  1003  	s.init()
  1004  
  1005  	test.Panic(t, func() {
  1006  		_ = s.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
  1007  	})
  1008  }
  1009  
  1010  func TestRunServiceWithoutSvcInfo(t *testing.T) {
  1011  	svr := NewServer()
  1012  	time.AfterFunc(100*time.Millisecond, func() {
  1013  		_ = svr.Stop()
  1014  	})
  1015  	err := svr.Run()
  1016  	test.Assert(t, err != nil)
  1017  	test.Assert(t, strings.Contains(err.Error(), "no service"))
  1018  }
  1019  
  1020  func inboundDeepEqual(inbound1, inbound2 []remote.InboundHandler) bool {
  1021  	if len(inbound1) != len(inbound2) {
  1022  		return false
  1023  	}
  1024  	for i := 0; i < len(inbound1); i++ {
  1025  		if !bound.DeepEqual(inbound1[i], inbound2[i]) {
  1026  			return false
  1027  		}
  1028  	}
  1029  	return true
  1030  }
  1031  
  1032  func withGRPCTransport() Option {
  1033  	return Option{F: func(o *internal_server.Options, di *utils.Slice) {
  1034  		o.RemoteOpt.SvrHandlerFactory = nphttp2.NewSvrTransHandlerFactory()
  1035  	}}
  1036  }