trpc.group/trpc-go/trpc-go@v1.0.2/server/service_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package server_test
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"math/rand"
    20  	"net"
    21  	"os"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  
    28  	"trpc.group/trpc-go/trpc-go/client"
    29  	"trpc.group/trpc-go/trpc-go/codec"
    30  	"trpc.group/trpc-go/trpc-go/errs"
    31  	"trpc.group/trpc-go/trpc-go/filter"
    32  	"trpc.group/trpc-go/trpc-go/log"
    33  	"trpc.group/trpc-go/trpc-go/naming/registry"
    34  	"trpc.group/trpc-go/trpc-go/restful"
    35  	"trpc.group/trpc-go/trpc-go/server"
    36  	pb "trpc.group/trpc-go/trpc-go/testdata/trpc/helloworld"
    37  	"trpc.group/trpc-go/trpc-go/transport"
    38  )
    39  
    40  func init() {
    41  	rand.Seed(time.Now().Unix())
    42  }
    43  
    44  // go test -v
    45  type fakeTransport struct {
    46  }
    47  
    48  func (s *fakeTransport) ListenAndServe(ctx context.Context, opts ...transport.ListenServeOption) error {
    49  	lsopts := &transport.ListenServeOptions{}
    50  	for _, opt := range opts {
    51  		opt(lsopts)
    52  	}
    53  
    54  	go func() {
    55  		lsopts.Handler.Handle(ctx, []byte("normal-request"))
    56  		lsopts.Handler.Handle(ctx, []byte("stream"))
    57  		lsopts.Handler.Handle(ctx, []byte("no-rpc-name"))
    58  		lsopts.Handler.Handle(ctx, []byte("decode-error"))
    59  		lsopts.Handler.Handle(ctx, []byte("encode-error"))
    60  		lsopts.Handler.Handle(ctx, []byte("handle-timeout"))
    61  		lsopts.Handler.Handle(ctx, []byte("no-response"))
    62  		lsopts.Handler.Handle(ctx, []byte("business-fail"))
    63  		lsopts.Handler.Handle(ctx, []byte("handle-panic"))
    64  		lsopts.Handler.Handle(ctx, []byte("compress-error"))
    65  		lsopts.Handler.Handle(ctx, []byte("decompress-error"))
    66  		lsopts.Handler.Handle(ctx, []byte("unmarshal-error"))
    67  		lsopts.Handler.Handle(ctx, []byte("marshal-error"))
    68  		ctx := context.Background()
    69  		ctx, msg := codec.WithNewMessage(ctx)
    70  		msg.WithServerRspErr(errors.New("connection is tryClose "))
    71  		lsopts.Handler.Handle(ctx, nil)
    72  
    73  	}()
    74  
    75  	return nil
    76  }
    77  
    78  type fakeCodec struct {
    79  }
    80  
    81  func (c *fakeCodec) Decode(msg codec.Msg, reqBuf []byte) (reqBody []byte, err error) {
    82  	req := string(reqBuf)
    83  
    84  	if req == "stream" {
    85  		msg.WithServerRPCName("/trpc.test.helloworld.Greeter/SayHi")
    86  		return reqBuf, nil
    87  	}
    88  	if req != "no-rpc-name" {
    89  		msg.WithServerRPCName("/trpc.test.helloworld.Greeter/SayHello")
    90  	}
    91  	if req == "decode-error" {
    92  		return nil, errors.New("server decode request fail")
    93  	}
    94  	msg.WithRequestTimeout(time.Second)
    95  	msg.WithSerializationType(codec.SerializationTypeNoop)
    96  	log.Infof("fakeCodec ==> req[%v]", req)
    97  	return reqBuf, nil
    98  }
    99  
   100  func (c *fakeCodec) Encode(msg codec.Msg, rspBody []byte) (rspBuf []byte, err error) {
   101  	rsp := string(rspBody)
   102  	if rsp == "encode-error" {
   103  		return nil, errors.New("server encode response fail")
   104  	}
   105  	return rspBody, nil
   106  }
   107  
   108  func (c *fakeCodec) Compress(in []byte) (out []byte, err error) {
   109  	rsp := string(in)
   110  	if rsp == "compress-error" {
   111  		return nil, errors.New("server compress fail")
   112  	}
   113  	return in, nil
   114  }
   115  
   116  func (c *fakeCodec) Decompress(in []byte) (out []byte, err error) {
   117  	req := string(in)
   118  	if req == "decompress-error" {
   119  		return nil, errors.New("server decompress fail")
   120  	}
   121  	return in, nil
   122  }
   123  
   124  func (c *fakeCodec) Unmarshal(reqBuf []byte, reqBody interface{}) error {
   125  	req := string(reqBuf)
   126  	if req == "unmarshal-error" {
   127  		return errors.New("server unmarshal fail")
   128  	}
   129  	return codec.Unmarshal(codec.SerializationTypeNoop, reqBuf, reqBody)
   130  }
   131  
   132  func (c *fakeCodec) Marshal(rspBody interface{}) (rspBuf []byte, err error) {
   133  	if rsp, ok := rspBody.(*codec.Body); ok {
   134  		if string(rsp.Data) == "marshal-error" {
   135  			return nil, errors.New("server marshal fail")
   136  		}
   137  	}
   138  	return codec.Marshal(codec.SerializationTypeNoop, rspBody)
   139  }
   140  
   141  type fakeRegistry struct {
   142  }
   143  
   144  func (r *fakeRegistry) Register(service string, opt ...registry.Option) error {
   145  	return nil
   146  }
   147  func (r *fakeRegistry) Deregister(service string) error {
   148  	return nil
   149  }
   150  
   151  func TestService(t *testing.T) {
   152  	codec.Register("fake", &fakeCodec{}, nil)
   153  	// register the fake codec
   154  	codec.RegisterCompressor(930, &fakeCodec{})
   155  	codec.RegisterSerializer(1930, &fakeCodec{})
   156  
   157  	// 1.codec not set,transport will cause error.
   158  	service := server.New(server.WithServiceName("trpc.test.helloworld.Greeter"),
   159  		server.WithTransport(&fakeTransport{}),
   160  		server.WithRegistry(&registry.NoopRegistry{}))
   161  
   162  	impl := &GreeterServerImpl{}
   163  	err := service.Register(&GreeterServerServiceDesc, impl)
   164  	assert.Nil(t, err)
   165  
   166  	go func() {
   167  		_ = service.Serve()
   168  	}()
   169  	// closing service will not return error even if registry fails.
   170  	err = service.Close(nil)
   171  	assert.Nil(t, err)
   172  
   173  	// 2. valid service registration
   174  	service = server.New(server.WithProtocol("fake"),
   175  		server.WithServiceName("trpc.test.helloworld.Greeter"),
   176  		server.WithTransport(&fakeTransport{}),
   177  		server.WithRegistry(&fakeRegistry{}),
   178  		server.WithCurrentSerializationType(1930),
   179  		server.WithCurrentCompressType(930),
   180  		server.WithCloseWaitTime(100*time.Millisecond),
   181  		server.WithMaxCloseWaitTime(200*time.Millisecond))
   182  	err = service.Register(&GreeterServerServiceDesc, impl)
   183  	assert.Nil(t, err)
   184  
   185  	// RESTful router should exist
   186  	assert.NotNil(t, restful.GetRouter("trpc.test.helloworld.Greeter"))
   187  
   188  	go func() {
   189  		_ = service.Serve()
   190  	}()
   191  	time.Sleep(time.Second * 2)
   192  	err = service.Close(nil)
   193  	assert.Nil(t, err)
   194  }
   195  
   196  // TestServiceFail tests failures of request handling.
   197  func TestServiceFail(t *testing.T) {
   198  
   199  	codec.Register("fake", &fakeCodec{}, nil)
   200  	service := server.New(server.WithProtocol("fake"),
   201  		server.WithServiceName("trpc.test.helloworld.Greeter"),
   202  		server.WithTransport(&fakeTransport{}),
   203  		server.WithRegistry(&fakeRegistry{}),
   204  	)
   205  
   206  	impl := &GreeterServerImpl{}
   207  	err := service.Register(&GreeterServerServiceDescFail, impl)
   208  	assert.Nil(t, err)
   209  	go func() {
   210  		service.Serve()
   211  	}()
   212  
   213  	time.Sleep(time.Second * 2)
   214  }
   215  
   216  // TestServiceMethodNameUniqueness tests method name uniqueness
   217  func TestServiceMethodNameUniqueness(t *testing.T) {
   218  	codec.Register("fake", &fakeCodec{}, nil)
   219  	service := server.New(server.WithProtocol("fake"),
   220  		server.WithServiceName("trpc.test.helloworld.Greeter"),
   221  		server.WithTransport(&fakeTransport{}),
   222  		server.WithRegistry(&fakeRegistry{}),
   223  	)
   224  
   225  	impl := &GreeterServerImpl{}
   226  	err := service.Register(&GreeterServerServiceDescFail, impl)
   227  	assert.Nil(t, err)
   228  
   229  	err = service.Register(&GreeterServerServiceDescFail, impl)
   230  	assert.NotNil(t, err)
   231  }
   232  
   233  func TestServiceTimeout(t *testing.T) {
   234  	require.Nil(t, os.Setenv(transport.EnvGraceRestart, ""))
   235  	t.Run("server timeout", func(t *testing.T) {
   236  		addr, stop := startService(t, &GreeterServerImpl{},
   237  			server.WithTimeout(time.Second),
   238  			server.WithFilter(
   239  				func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) {
   240  					return nil, errs.NewFrameError(errs.RetServerTimeout, "")
   241  				}))
   242  		defer stop()
   243  
   244  		c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr))
   245  		_, err := c.SayHello(context.Background(), &pb.HelloRequest{})
   246  		require.NotNil(t, err)
   247  		e, ok := err.(*errs.Error)
   248  		require.True(t, ok)
   249  		require.EqualValues(t, int32(errs.RetServerTimeout), e.Code)
   250  	})
   251  	t.Run("client full link timeout is converted to server timeout",
   252  		func(t *testing.T) {
   253  			addr, stop := startService(t,
   254  				&Greeter{
   255  					sayHello: func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) {
   256  						return nil, errs.NewFrameError(errs.RetClientFullLinkTimeout, "")
   257  					}},
   258  				server.WithTimeout(time.Second))
   259  			defer stop()
   260  
   261  			c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr))
   262  			_, err := c.SayHello(ctx, &pb.HelloRequest{})
   263  			require.NotNil(t, err)
   264  			e, ok := err.(*errs.Error)
   265  			require.True(t, ok)
   266  			require.Equal(t, errs.ErrorTypeCalleeFramework, e.Type)
   267  			require.EqualValues(t, int32(errs.RetServerTimeout), e.Code)
   268  		})
   269  	t.Run("client full link timeout is converted to server full link timeout, and then dropped",
   270  		func(t *testing.T) {
   271  			addr, stop := startService(t,
   272  				&Greeter{
   273  					sayHello: func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) {
   274  						return nil, errs.NewFrameError(errs.RetClientFullLinkTimeout, "")
   275  					}},
   276  				server.WithTimeout(time.Second*2))
   277  			defer stop()
   278  
   279  			c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr))
   280  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   281  			defer cancel()
   282  			_, err := c.SayHello(ctx, &pb.HelloRequest{})
   283  			require.NotNil(t, err)
   284  			e, ok := err.(*errs.Error)
   285  			require.True(t, ok)
   286  			require.Equal(t, errs.ErrorTypeFramework, e.Type)
   287  			require.EqualValues(t, int32(errs.RetClientFullLinkTimeout), e.Code,
   288  				"server full link timeout is dropped, and client should receive a client timeout error")
   289  		})
   290  }
   291  
   292  func TestServiceUDP(t *testing.T) {
   293  	addr := "127.0.0.1:10000"
   294  	s := server.New([]server.Option{
   295  		server.WithNetwork("udp"),
   296  		server.WithProtocol("trpc"),
   297  		server.WithAddress(addr),
   298  		server.WithCurrentSerializationType(codec.SerializationTypeNoop),
   299  	}...)
   300  	require.Nil(t, s.Register(&GreeterServerServiceDesc, &GreeterServerImpl{}))
   301  	go s.Serve()
   302  	time.Sleep(time.Millisecond * 200)
   303  
   304  	c := pb.NewGreeterClientProxy(client.WithTarget("ip://"+addr), client.WithNetwork("udp"))
   305  	_, err := c.SayHello(context.Background(), &pb.HelloRequest{})
   306  	require.Nil(t, err)
   307  }
   308  
   309  func TestServiceCloseWait(t *testing.T) {
   310  	const waitChildTime = 300 * time.Millisecond
   311  	const schTime = 10 * time.Millisecond
   312  	cases := []struct {
   313  		closeWaitTime    time.Duration
   314  		maxCloseWaitTime time.Duration
   315  		waitTime         time.Duration
   316  	}{
   317  		{
   318  			waitTime: waitChildTime,
   319  		},
   320  		{
   321  			closeWaitTime: 50 * time.Millisecond,
   322  			waitTime:      waitChildTime + 50*time.Millisecond,
   323  		},
   324  		{
   325  			closeWaitTime:    50 * time.Millisecond,
   326  			maxCloseWaitTime: 30 * time.Millisecond,
   327  			waitTime:         waitChildTime + 50*time.Millisecond,
   328  		},
   329  		{
   330  			closeWaitTime:    50 * time.Millisecond,
   331  			maxCloseWaitTime: 100 * time.Millisecond,
   332  			waitTime:         waitChildTime + 50*time.Millisecond,
   333  		},
   334  	}
   335  	for _, c := range cases {
   336  		service := server.New(
   337  			server.WithRegistry(&fakeRegistry{}),
   338  			server.WithCloseWaitTime(c.closeWaitTime),
   339  			server.WithMaxCloseWaitTime(c.maxCloseWaitTime),
   340  		)
   341  		start := time.Now()
   342  		err := service.Close(nil)
   343  		assert.Nil(t, err)
   344  		cost := time.Since(start)
   345  		assert.GreaterOrEqual(t, cost, c.waitTime)
   346  		assert.LessOrEqual(t, cost, c.waitTime+schTime)
   347  	}
   348  }
   349  
   350  func startService(t *testing.T, gs GreeterServer, opts ...server.Option) (addr string, stop func()) {
   351  	l, err := net.Listen("tcp", "0.0.0.0:0")
   352  	require.Nil(t, err)
   353  
   354  	s := server.New(append(append([]server.Option{
   355  		server.WithNetwork("tcp"),
   356  		server.WithProtocol("trpc"),
   357  	}, opts...),
   358  		server.WithListener(l),
   359  	)...)
   360  	require.Nil(t, s.Register(&GreeterServerServiceDesc, gs))
   361  
   362  	errCh := make(chan error)
   363  	go func() { errCh <- s.Serve() }()
   364  	select {
   365  	case err := <-errCh:
   366  		require.FailNow(t, "serve failed", err)
   367  	case <-time.After(time.Millisecond * 200):
   368  	}
   369  	return l.Addr().String(), func() { s.Close(nil) }
   370  }
   371  
   372  func TestGetStreamFilter(t *testing.T) {
   373  	expectedErr := errors.New("expected error")
   374  	testFilter := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error {
   375  		return expectedErr
   376  	}
   377  	server.RegisterStreamFilter("testFilter", testFilter)
   378  	filter := server.GetStreamFilter("testFilter")
   379  	err := filter(nil, &server.StreamServerInfo{}, nil)
   380  	assert.Equal(t, expectedErr, err)
   381  }
   382  
   383  type Greeter struct {
   384  	sayHello func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error)
   385  }
   386  
   387  func (g *Greeter) SayHello(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) {
   388  	return g.sayHello(ctx, req)
   389  }
   390  
   391  func (*Greeter) SayHi(gs Greeter_SayHiServer) error {
   392  	return nil
   393  }
   394  
   395  func TestStreamFilterChainFilter(t *testing.T) {
   396  	ch := make(chan int, 10)
   397  	sf1 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error {
   398  		ch <- 1
   399  		err := handler(ss)
   400  		ch <- 5
   401  		return err
   402  	}
   403  	sf2 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error {
   404  		ch <- 2
   405  		err := handler(ss)
   406  		ch <- 4
   407  		return err
   408  	}
   409  	option := server.WithStreamFilters(sf1, sf2)
   410  	options := server.Options{}
   411  	option(&options)
   412  	_ = options.StreamFilters.Filter(nil, nil, func(stream server.Stream) error {
   413  		ch <- 3
   414  		return nil
   415  	})
   416  	assert.Equal(t, 1, <-ch)
   417  	assert.Equal(t, 2, <-ch)
   418  	assert.Equal(t, 3, <-ch)
   419  	assert.Equal(t, 4, <-ch)
   420  	assert.Equal(t, 5, <-ch)
   421  }