trpc.group/trpc-go/trpc-go@v1.0.3/server/server_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package server_test
    15  
    16  import (
    17  	"context"
    18  	"os"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  
    25  	"trpc.group/trpc-go/trpc-go/codec"
    26  	"trpc.group/trpc-go/trpc-go/errs"
    27  	"trpc.group/trpc-go/trpc-go/naming/registry"
    28  	"trpc.group/trpc-go/trpc-go/restful"
    29  	"trpc.group/trpc-go/trpc-go/server"
    30  	"trpc.group/trpc-go/trpc-go/transport"
    31  )
    32  
// GreeterServer is the handler interface of the test Greeter service.
// SayHello is the unary method; SayHi is the streaming method and is handed
// the server-side stream.
type GreeterServer interface {
	SayHello(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error)
	SayHi(Greeter_SayHiServer) error
}
    38  
// Greeter_SayHiServer is the server-side stream passed to GreeterServer.SayHi.
// It extends server.Stream with typed Send and Recv methods that carry
// *codec.Body messages.
type Greeter_SayHiServer interface {
	Send(*codec.Body) error
	Recv() (*codec.Body, error)
	server.Stream
}
    45  
// greeterSayHiServer is the Greeter_SayHiServer implementation used in these
// tests. It embeds the underlying server.Stream and therefore inherits its
// methods.
type greeterSayHiServer struct {
	server.Stream
}
    50  
    51  func (x *greeterSayHiServer) Send(m *codec.Body) error {
    52  	return nil
    53  }
    54  
    55  func (x *greeterSayHiServer) Recv() (*codec.Body, error) {
    56  	return nil, nil
    57  }
    58  
    59  func GreeterServerSayHelloHandler(svr interface{}, ctx context.Context,
    60  	f server.FilterFunc) (rspBody interface{}, err error) {
    61  	req := &codec.Body{}
    62  	filters, err := f(req)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	handleFunc := func(ctx context.Context, reqBody interface{}) (interface{}, error) {
    67  		return svr.(GreeterServer).SayHello(ctx, reqBody.(*codec.Body))
    68  	}
    69  	return filters.Filter(ctx, req, handleFunc)
    70  }
    71  
// GreeterServerImpl is the GreeterServer implementation used by the tests.
type GreeterServerImpl struct{}
// FailServerImpl is a type registered by the failure tests, which expect
// Register to return an error for it. No methods are declared on it in the
// visible part of this file.
type FailServerImpl struct{}
    74  
    75  func (s *GreeterServerImpl) SayHello(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) {
    76  	rsp = &codec.Body{}
    77  	rsp.Data = req.Data
    78  	if string(req.Data) == "handle-timeout" {
    79  		time.Sleep(time.Second * 2)
    80  	}
    81  	if string(req.Data) == "no-response" {
    82  		return nil, errs.ErrServerNoResponse
    83  	}
    84  	if string(req.Data) == "business-fail" {
    85  		return nil, errs.New(1000, "inner db fail")
    86  	}
    87  	return rsp, nil
    88  }
    89  
    90  func (s *GreeterServerImpl) SayHi(gs Greeter_SayHiServer) error {
    91  	return nil
    92  }
    93  
// fakeStreamHandle is a stub stream handle used as the StreamHandle of
// GreeterServerServiceDesc; its methods do nothing.
type fakeStreamHandle struct {
}
    96  
    97  func (fs *fakeStreamHandle) StreamHandleFunc(ctx context.Context, sh server.StreamHandler, si *server.StreamServerInfo, req []byte) ([]byte, error) {
    98  	return nil, nil
    99  }
   100  
   101  func (fs *fakeStreamHandle) Init(opts *server.Options) error {
   102  	return nil
   103  }
   104  
   105  func GreeterService_SayHi_Handler(srv interface{}, stream server.Stream) error {
   106  	return srv.(GreeterServer).SayHi(&greeterSayHiServer{stream})
   107  }
   108  
// GreeterServerServiceDesc is the Greeter service descriptor passed to
// Server.Register in these tests. It declares one unary method (SayHello,
// which also carries a RESTful binding on /v1/foobar), one stream (SayHi), and
// uses the stub fakeStreamHandle as its stream handle.
var GreeterServerServiceDesc = server.ServiceDesc{
	ServiceName:  "trpc.test.helloworld.Greeter",
	HandlerType:  (*GreeterServer)(nil),
	StreamHandle: &fakeStreamHandle{},
	Methods: []server.Method{
		{
			Name: "/trpc.test.helloworld.Greeter/SayHello",
			Func: GreeterServerSayHelloHandler,
			Bindings: []*restful.Binding{
				{
					Name:    "/trpc.test.helloworld.Greeter/SayHello",
					Pattern: restful.Enforce("/v1/foobar"),
				},
			},
		},
	},
	Streams: []server.StreamDesc{
		{
			StreamName:    "/trpc.test.helloworld.Greeter/SayHi",
			Handler:       GreeterService_SayHi_Handler,
			ServerStreams: true,
		},
	},
}
   134  
// GreeterServerServiceDescFail is a variant of the Greeter service descriptor
// that declares a stream but leaves StreamHandle nil, and has no RESTful
// binding on SayHello.
// NOTE(review): presumably intended to make stream registration fail; it is
// not referenced in the visible part of this file - confirm.
var GreeterServerServiceDescFail = server.ServiceDesc{
	ServiceName:  "trpc.test.helloworld.Greeter",
	HandlerType:  (*GreeterServer)(nil),
	StreamHandle: nil,
	Methods: []server.Method{
		{
			Name: "/trpc.test.helloworld.Greeter/SayHello",
			Func: GreeterServerSayHelloHandler,
		},
	},
	Streams: []server.StreamDesc{
		{
			StreamName:    "/trpc.test.helloworld.Greeter/SayHi",
			Handler:       GreeterService_SayHi_Handler,
			ServerStreams: true,
		},
	},
}
   154  
   155  func TestServeFail(t *testing.T) {
   156  	t.Run("test empty service", func(t *testing.T) {
   157  		s := &server.Server{}
   158  		assert.Panics(t, func() { s.Serve() }, "service empty")
   159  	})
   160  	t.Run("network mismatching", func(t *testing.T) {
   161  		s := &server.Server{}
   162  		s.AddService("trpc.test.helloworld.Greeter1", server.New(
   163  			server.WithNetwork("tcp9"),
   164  			server.WithAddress("127.0.0.1:8080"),
   165  			server.WithProtocol("trpc"),
   166  			server.WithServiceName("trpc.test.helloworld.Greeter1")))
   167  		assert.NotNil(t, s.Register(&GreeterServerServiceDesc, &FailServerImpl{}))
   168  		assert.NotNil(t, s.Serve())
   169  	})
   170  	t.Run("registry failure", func(t *testing.T) {
   171  		s := &server.Server{}
   172  		s.AddService("trpc.test.helloworld.Greeter", server.New(
   173  			server.WithAddress("127.0.0.1:8081"),
   174  			server.WithRegistry(&registry.NoopRegistry{})))
   175  		assert.NotNil(t, s.Register(&GreeterServerServiceDesc, &FailServerImpl{}))
   176  		assert.NotNil(t, s.Serve())
   177  	})
   178  }
   179  
   180  func TestServer(t *testing.T) {
   181  	// If the process is started by graceful restart,
   182  	// exit here in case of infinite loop.
   183  	if len(os.Getenv(transport.EnvGraceRestart)) > 0 {
   184  		t.SkipNow()
   185  	}
   186  	s := &server.Server{}
   187  
   188  	// 1. try to get service that not exists.
   189  	assert.Nil(t, s.Service("empty"))
   190  
   191  	service1 := server.New(server.WithAddress("127.0.0.1:12345"),
   192  		server.WithNetwork("tcp"),
   193  		server.WithProtocol("trpc"),
   194  		server.WithServiceName("trpc.test.helloworld.Greeter1"))
   195  
   196  	service2 := server.New(server.WithAddress("127.0.0.1:12346"),
   197  		server.WithNetwork("tcp"),
   198  		server.WithProtocol("trpc"),
   199  		server.WithServiceName("trpc.test.helloworld.Greeter2"))
   200  
   201  	s.AddService("trpc.test.helloworld.Greeter1", service1)
   202  	s.AddService("trpc.test.helloworld.Greeter2", service2)
   203  
   204  	assert.Equal(t, service1, s.Service("trpc.test.helloworld.Greeter1"))
   205  	assert.Equal(t, service2, s.Service("trpc.test.helloworld.Greeter2"))
   206  	assert.Nil(t, s.Service("empty"))
   207  
   208  	// 2. test registering empty proto service.
   209  	err := s.Register(nil, nil)
   210  	assert.NotNil(t, err)
   211  
   212  	impl := &GreeterServerImpl{}
   213  	err = s.Register(&GreeterServerServiceDesc, impl)
   214  	assert.Nil(t, err)
   215  
   216  	// 3. valid serving.
   217  	go func() {
   218  		err = os.Setenv(transport.EnvGraceRestart, "")
   219  		assert.Nil(t, err)
   220  
   221  		err = s.Serve()
   222  		assert.Nil(t, err)
   223  	}()
   224  
   225  	time.Sleep(time.Second * 1)
   226  	err = s.Close(nil)
   227  	assert.Nil(t, err)
   228  }
   229  
   230  func TestServerClose(t *testing.T) {
   231  	const schTime = 10 * time.Millisecond
   232  	cases := []struct {
   233  		maxCloseWaitTime time.Duration
   234  	}{
   235  		{},
   236  		{
   237  			maxCloseWaitTime: server.MaxCloseWaitTime / 2,
   238  		},
   239  		{
   240  			maxCloseWaitTime: server.MaxCloseWaitTime,
   241  		},
   242  		{
   243  			maxCloseWaitTime: server.MaxCloseWaitTime * 2,
   244  		},
   245  	}
   246  	for _, c := range cases {
   247  		s := &server.Server{
   248  			MaxCloseWaitTime: c.maxCloseWaitTime,
   249  		}
   250  		start := time.Now()
   251  		s.Close(nil)
   252  		et := time.Since(start)
   253  		assert.Less(t, et, schTime)
   254  	}
   255  }
   256  
   257  // TestServer_AtExit tests whether order of execution of shutdown hook functions matches
   258  // order of registration of shutdown hook functions.
   259  func TestServer_AtExit_ExecuteOrder(t *testing.T) {
   260  	s := &server.Server{}
   261  	const num = 3
   262  	ch := make(chan int, num)
   263  	for i := 0; i < num; i++ {
   264  		// temporary variable j helps capture the iteration variable i.
   265  		j := i
   266  		s.RegisterOnShutdown(func() { ch <- j })
   267  	}
   268  	s.RegisterOnShutdown(func() { close(ch) })
   269  
   270  	require.Nil(t, s.Close(nil))
   271  
   272  	for i := 0; i < num; i++ {
   273  		require.Equal(t, i, <-ch)
   274  	}
   275  	_, ok := <-ch
   276  	require.False(t, ok)
   277  }