trpc.group/trpc-go/trpc-go@v1.0.3/server/service_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package server_test
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"math/rand"
    20  	"net"
    21  	"os"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  
    28  	"trpc.group/trpc-go/trpc-go/client"
    29  	"trpc.group/trpc-go/trpc-go/codec"
    30  	"trpc.group/trpc-go/trpc-go/errs"
    31  	"trpc.group/trpc-go/trpc-go/filter"
    32  	"trpc.group/trpc-go/trpc-go/log"
    33  	"trpc.group/trpc-go/trpc-go/naming/registry"
    34  	"trpc.group/trpc-go/trpc-go/restful"
    35  	"trpc.group/trpc-go/trpc-go/server"
    36  	pb "trpc.group/trpc-go/trpc-go/testdata/trpc/helloworld"
    37  	"trpc.group/trpc-go/trpc-go/transport"
    38  )
    39  
    40  func init() {
    41  	rand.Seed(time.Now().Unix())
    42  }
    43  
    44  // go test -v
    45  type fakeTransport struct {
    46  }
    47  
    48  func (s *fakeTransport) ListenAndServe(ctx context.Context, opts ...transport.ListenServeOption) error {
    49  	lsopts := &transport.ListenServeOptions{}
    50  	for _, opt := range opts {
    51  		opt(lsopts)
    52  	}
    53  
    54  	go func() {
    55  		lsopts.Handler.Handle(ctx, []byte("normal-request"))
    56  		lsopts.Handler.Handle(ctx, []byte("stream"))
    57  		lsopts.Handler.Handle(ctx, []byte("no-rpc-name"))
    58  		lsopts.Handler.Handle(ctx, []byte("decode-error"))
    59  		lsopts.Handler.Handle(ctx, []byte("encode-error"))
    60  		lsopts.Handler.Handle(ctx, []byte("handle-timeout"))
    61  		lsopts.Handler.Handle(ctx, []byte("no-response"))
    62  		lsopts.Handler.Handle(ctx, []byte("business-fail"))
    63  		lsopts.Handler.Handle(ctx, []byte("handle-panic"))
    64  		lsopts.Handler.Handle(ctx, []byte("compress-error"))
    65  		lsopts.Handler.Handle(ctx, []byte("decompress-error"))
    66  		lsopts.Handler.Handle(ctx, []byte("unmarshal-error"))
    67  		lsopts.Handler.Handle(ctx, []byte("marshal-error"))
    68  		ctx := context.Background()
    69  		ctx, msg := codec.WithNewMessage(ctx)
    70  		msg.WithServerRspErr(errors.New("connection is tryClose "))
    71  		lsopts.Handler.Handle(ctx, nil)
    72  
    73  	}()
    74  
    75  	return nil
    76  }
    77  
    78  type fakeCodec struct {
    79  }
    80  
    81  func (c *fakeCodec) Decode(msg codec.Msg, reqBuf []byte) (reqBody []byte, err error) {
    82  	req := string(reqBuf)
    83  
    84  	if req == "stream" {
    85  		msg.WithServerRPCName("/trpc.test.helloworld.Greeter/SayHi")
    86  		return reqBuf, nil
    87  	}
    88  	if req != "no-rpc-name" {
    89  		msg.WithServerRPCName("/trpc.test.helloworld.Greeter/SayHello")
    90  	}
    91  	if req == "decode-error" {
    92  		return nil, errors.New("server decode request fail")
    93  	}
    94  	msg.WithRequestTimeout(time.Second)
    95  	msg.WithSerializationType(codec.SerializationTypeNoop)
    96  	log.Infof("fakeCodec ==> req[%v]", req)
    97  	return reqBuf, nil
    98  }
    99  
   100  func (c *fakeCodec) Encode(msg codec.Msg, rspBody []byte) (rspBuf []byte, err error) {
   101  	rsp := string(rspBody)
   102  	if rsp == "encode-error" {
   103  		return nil, errors.New("server encode response fail")
   104  	}
   105  	return rspBody, nil
   106  }
   107  
   108  func (c *fakeCodec) Compress(in []byte) (out []byte, err error) {
   109  	rsp := string(in)
   110  	if rsp == "compress-error" {
   111  		return nil, errors.New("server compress fail")
   112  	}
   113  	return in, nil
   114  }
   115  
   116  func (c *fakeCodec) Decompress(in []byte) (out []byte, err error) {
   117  	req := string(in)
   118  	if req == "decompress-error" {
   119  		return nil, errors.New("server decompress fail")
   120  	}
   121  	return in, nil
   122  }
   123  
   124  func (c *fakeCodec) Unmarshal(reqBuf []byte, reqBody interface{}) error {
   125  	req := string(reqBuf)
   126  	if req == "unmarshal-error" {
   127  		return errors.New("server unmarshal fail")
   128  	}
   129  	return codec.Unmarshal(codec.SerializationTypeNoop, reqBuf, reqBody)
   130  }
   131  
   132  func (c *fakeCodec) Marshal(rspBody interface{}) (rspBuf []byte, err error) {
   133  	if rsp, ok := rspBody.(*codec.Body); ok {
   134  		if string(rsp.Data) == "marshal-error" {
   135  			return nil, errors.New("server marshal fail")
   136  		}
   137  	}
   138  	return codec.Marshal(codec.SerializationTypeNoop, rspBody)
   139  }
   140  
   141  type fakeRegistry struct {
   142  }
   143  
   144  func (r *fakeRegistry) Register(service string, opt ...registry.Option) error {
   145  	return nil
   146  }
   147  func (r *fakeRegistry) Deregister(service string) error {
   148  	return nil
   149  }
   150  
   151  func TestService(t *testing.T) {
   152  	codec.Register("fake", &fakeCodec{}, nil)
   153  	// register the fake codec
   154  	codec.RegisterCompressor(930, &fakeCodec{})
   155  	codec.RegisterSerializer(1930, &fakeCodec{})
   156  
   157  	// 1.codec not set,transport will cause error.
   158  	service := server.New(server.WithServiceName("trpc.test.helloworld.Greeter"),
   159  		server.WithTransport(&fakeTransport{}),
   160  		server.WithRegistry(&registry.NoopRegistry{}))
   161  
   162  	impl := &GreeterServerImpl{}
   163  	err := service.Register(&GreeterServerServiceDesc, impl)
   164  	assert.Nil(t, err)
   165  
   166  	go func() {
   167  		_ = service.Serve()
   168  	}()
   169  	// closing service will not return error even if registry fails.
   170  	err = service.Close(nil)
   171  	assert.Nil(t, err)
   172  
   173  	// 2. valid service registration
   174  	service = server.New(server.WithProtocol("fake"),
   175  		server.WithServiceName("trpc.test.helloworld.Greeter"),
   176  		server.WithTransport(&fakeTransport{}),
   177  		server.WithRegistry(&fakeRegistry{}),
   178  		server.WithCurrentSerializationType(1930),
   179  		server.WithCurrentCompressType(930),
   180  		server.WithCloseWaitTime(100*time.Millisecond),
   181  		server.WithMaxCloseWaitTime(200*time.Millisecond))
   182  	err = service.Register(&GreeterServerServiceDesc, impl)
   183  	assert.Nil(t, err)
   184  
   185  	// RESTful router should exist
   186  	assert.NotNil(t, restful.GetRouter("trpc.test.helloworld.Greeter"))
   187  
   188  	go func() {
   189  		_ = service.Serve()
   190  	}()
   191  	time.Sleep(time.Second * 2)
   192  	err = service.Close(nil)
   193  	assert.Nil(t, err)
   194  }
   195  
   196  // TestServiceFail tests failures of request handling.
   197  func TestServiceFail(t *testing.T) {
   198  
   199  	codec.Register("fake", &fakeCodec{}, nil)
   200  	service := server.New(server.WithProtocol("fake"),
   201  		server.WithServiceName("trpc.test.helloworld.Greeter"),
   202  		server.WithTransport(&fakeTransport{}),
   203  		server.WithRegistry(&fakeRegistry{}),
   204  	)
   205  
   206  	impl := &GreeterServerImpl{}
   207  	err := service.Register(&GreeterServerServiceDescFail, impl)
   208  	assert.Nil(t, err)
   209  	go func() {
   210  		service.Serve()
   211  	}()
   212  
   213  	time.Sleep(time.Second * 2)
   214  }
   215  
   216  // TestServiceMethodNameUniqueness tests method name uniqueness
   217  func TestServiceMethodNameUniqueness(t *testing.T) {
   218  	codec.Register("fake", &fakeCodec{}, nil)
   219  	service := server.New(server.WithProtocol("fake"),
   220  		server.WithServiceName("trpc.test.helloworld.Greeter"),
   221  		server.WithTransport(&fakeTransport{}),
   222  		server.WithRegistry(&fakeRegistry{}),
   223  	)
   224  
   225  	impl := &GreeterServerImpl{}
   226  	err := service.Register(&GreeterServerServiceDescFail, impl)
   227  	assert.Nil(t, err)
   228  
   229  	err = service.Register(&GreeterServerServiceDescFail, impl)
   230  	assert.NotNil(t, err)
   231  }
   232  
   233  func TestServiceTimeout(t *testing.T) {
   234  	require.Nil(t, os.Setenv(transport.EnvGraceRestart, ""))
   235  	t.Run("server timeout", func(t *testing.T) {
   236  		addr, stop := startService(t, &GreeterServerImpl{},
   237  			server.WithTimeout(time.Second),
   238  			server.WithFilter(
   239  				func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) {
   240  					return nil, errs.NewFrameError(errs.RetServerTimeout, "")
   241  				}))
   242  		defer stop()
   243  
   244  		c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr))
   245  		_, err := c.SayHello(context.Background(), &pb.HelloRequest{})
   246  		require.NotNil(t, err)
   247  		e, ok := err.(*errs.Error)
   248  		require.True(t, ok)
   249  		require.EqualValues(t, int32(errs.RetServerTimeout), e.Code)
   250  	})
   251  	t.Run("client full link timeout is converted to server timeout",
   252  		func(t *testing.T) {
   253  			addr, stop := startService(t,
   254  				&Greeter{
   255  					sayHello: func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) {
   256  						return nil, errs.NewFrameError(errs.RetClientFullLinkTimeout, "")
   257  					}},
   258  				server.WithTimeout(time.Second))
   259  			defer stop()
   260  
   261  			c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr))
   262  			_, err := c.SayHello(ctx, &pb.HelloRequest{})
   263  			require.NotNil(t, err)
   264  			e, ok := err.(*errs.Error)
   265  			require.True(t, ok)
   266  			require.Equal(t, errs.ErrorTypeCalleeFramework, e.Type)
   267  			require.EqualValues(t, int32(errs.RetServerTimeout), e.Code)
   268  		})
   269  	t.Run("client full link timeout is converted to server full link timeout, and then dropped",
   270  		func(t *testing.T) {
   271  			addr, stop := startService(t,
   272  				&Greeter{
   273  					sayHello: func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) {
   274  						return nil, errs.NewFrameError(errs.RetClientFullLinkTimeout, "")
   275  					}},
   276  				server.WithTimeout(time.Second*2))
   277  			defer stop()
   278  
   279  			c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr))
   280  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   281  			defer cancel()
   282  			_, err := c.SayHello(ctx, &pb.HelloRequest{})
   283  			require.NotNil(t, err)
   284  			e, ok := err.(*errs.Error)
   285  			require.True(t, ok)
   286  			require.Equal(t, errs.ErrorTypeFramework, e.Type)
   287  			require.EqualValues(t, int32(errs.RetClientFullLinkTimeout), e.Code,
   288  				"server full link timeout is dropped, and client should receive a client timeout error")
   289  		})
   290  }
   291  
   292  func TestServiceUDP(t *testing.T) {
   293  	addr := "127.0.0.1:10000"
   294  	s := server.New([]server.Option{
   295  		server.WithNetwork("udp"),
   296  		server.WithProtocol("trpc"),
   297  		server.WithAddress(addr),
   298  		server.WithCurrentSerializationType(codec.SerializationTypeNoop),
   299  	}...)
   300  	require.Nil(t, s.Register(&GreeterServerServiceDesc, &GreeterServerImpl{}))
   301  	go s.Serve()
   302  	time.Sleep(time.Millisecond * 200)
   303  
   304  	c := pb.NewGreeterClientProxy(client.WithTarget("ip://"+addr), client.WithNetwork("udp"))
   305  	_, err := c.SayHello(context.Background(), &pb.HelloRequest{})
   306  	require.Nil(t, err)
   307  }
   308  
   309  func TestCloseWaitTime(t *testing.T) {
   310  	startService := func(opts ...server.Option) (chan struct{}, func()) {
   311  		received, done := make(chan struct{}), make(chan struct{})
   312  		addr, stop := startService(t, &Greeter{}, append([]server.Option{server.WithFilter(
   313  			func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) {
   314  				received <- struct{}{}
   315  				<-done
   316  				return nil, errors.New("must fail")
   317  			})}, opts...)...)
   318  		go func() {
   319  			_, _ = pb.NewGreeterClientProxy(client.WithTarget("ip://"+addr)).
   320  				SayHello(context.Background(), &pb.HelloRequest{})
   321  		}()
   322  		<-received
   323  		return done, stop
   324  	}
   325  	t.Run("active requests feature is not enabled on missing MaxCloseWaitTime", func(t *testing.T) {
   326  		done, stop := startService()
   327  		defer close(done)
   328  		start := time.Now()
   329  		stop()
   330  		require.Less(t, time.Since(start), time.Millisecond*100)
   331  	})
   332  	t.Run("total wait time should not significantly greater than MaxCloseWaitTime", func(t *testing.T) {
   333  		const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second
   334  		done, stop := startService(
   335  			server.WithMaxCloseWaitTime(maxCloseWaitTime),
   336  			server.WithCloseWaitTime(closeWaitTime))
   337  		defer close(done)
   338  		start := time.Now()
   339  		stop()
   340  		require.WithinRange(t, time.Now(),
   341  			// 300ms comes from the internal implementation when close service
   342  			start.Add(maxCloseWaitTime).Add(time.Millisecond*300),
   343  			start.Add(maxCloseWaitTime).Add(time.Millisecond*500))
   344  	})
   345  	t.Run("total wait time is at least CloseWaitTime", func(t *testing.T) {
   346  		const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second
   347  		done, stop := startService(
   348  			server.WithMaxCloseWaitTime(maxCloseWaitTime),
   349  			server.WithCloseWaitTime(closeWaitTime))
   350  		start := time.Now()
   351  		time.AfterFunc(closeWaitTime/2, func() { close(done) })
   352  		stop()
   353  		require.WithinRange(t, time.Now(), start.Add(closeWaitTime), start.Add(closeWaitTime+time.Millisecond*100))
   354  	})
   355  	t.Run("no active request before MaxCloseWaitTime", func(t *testing.T) {
   356  		const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second
   357  		done, stop := startService(
   358  			server.WithMaxCloseWaitTime(maxCloseWaitTime),
   359  			server.WithCloseWaitTime(closeWaitTime))
   360  		start := time.Now()
   361  		time.AfterFunc((closeWaitTime+maxCloseWaitTime)/2, func() { close(done) })
   362  		stop()
   363  		require.WithinRange(t, time.Now(), start.Add(closeWaitTime), start.Add(maxCloseWaitTime))
   364  	})
   365  	t.Run("no active request before service timeout", func(t *testing.T) {
   366  		const closeWaitTime, maxCloseWaitTime, timeout = time.Millisecond * 500, time.Second, time.Second
   367  		done, stop := startService(
   368  			server.WithMaxCloseWaitTime(maxCloseWaitTime),
   369  			server.WithCloseWaitTime(closeWaitTime),
   370  			server.WithTimeout(timeout))
   371  		start := time.Now()
   372  		time.AfterFunc(maxCloseWaitTime+time.Millisecond*100, func() { close(done) })
   373  		stop()
   374  		require.WithinRange(t, time.Now(), start.Add(maxCloseWaitTime+time.Millisecond*100), start.Add(maxCloseWaitTime+timeout))
   375  	})
   376  }
   377  
   378  func startService(t *testing.T, gs GreeterServer, opts ...server.Option) (addr string, stop func()) {
   379  	l, err := net.Listen("tcp", "0.0.0.0:0")
   380  	require.Nil(t, err)
   381  
   382  	s := server.New(append(append(
   383  		[]server.Option{
   384  			server.WithNetwork("tcp"),
   385  			server.WithProtocol("trpc"),
   386  		}, opts...),
   387  		server.WithListener(l),
   388  	)...)
   389  	require.Nil(t, s.Register(&GreeterServerServiceDesc, gs))
   390  
   391  	errCh := make(chan error)
   392  	go func() { errCh <- s.Serve() }()
   393  	select {
   394  	case err := <-errCh:
   395  		require.FailNow(t, "serve failed", err)
   396  	case <-time.After(time.Millisecond * 200):
   397  	}
   398  	return l.Addr().String(), func() { s.Close(nil) }
   399  }
   400  
   401  func TestGetStreamFilter(t *testing.T) {
   402  	expectedErr := errors.New("expected error")
   403  	testFilter := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error {
   404  		return expectedErr
   405  	}
   406  	server.RegisterStreamFilter("testFilter", testFilter)
   407  	filter := server.GetStreamFilter("testFilter")
   408  	err := filter(nil, &server.StreamServerInfo{}, nil)
   409  	assert.Equal(t, expectedErr, err)
   410  }
   411  
   412  type Greeter struct {
   413  	sayHello func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error)
   414  }
   415  
   416  func (g *Greeter) SayHello(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) {
   417  	return g.sayHello(ctx, req)
   418  }
   419  
   420  func (*Greeter) SayHi(gs Greeter_SayHiServer) error {
   421  	return nil
   422  }
   423  
   424  func TestStreamFilterChainFilter(t *testing.T) {
   425  	ch := make(chan int, 10)
   426  	sf1 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error {
   427  		ch <- 1
   428  		err := handler(ss)
   429  		ch <- 5
   430  		return err
   431  	}
   432  	sf2 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error {
   433  		ch <- 2
   434  		err := handler(ss)
   435  		ch <- 4
   436  		return err
   437  	}
   438  	option := server.WithStreamFilters(sf1, sf2)
   439  	options := server.Options{}
   440  	option(&options)
   441  	_ = options.StreamFilters.Filter(nil, nil, func(stream server.Stream) error {
   442  		ch <- 3
   443  		return nil
   444  	})
   445  	assert.Equal(t, 1, <-ch)
   446  	assert.Equal(t, 2, <-ch)
   447  	assert.Equal(t, 3, <-ch)
   448  	assert.Equal(t, 4, <-ch)
   449  	assert.Equal(t, 5, <-ch)
   450  }