trpc.group/trpc-go/trpc-go@v1.0.3/server/options_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package server_test
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"net"
    20  	"reflect"
    21  	"runtime"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  
    28  	"trpc.group/trpc-go/trpc-go/codec"
    29  	"trpc.group/trpc-go/trpc-go/filter"
    30  	"trpc.group/trpc-go/trpc-go/naming/registry"
    31  	"trpc.group/trpc-go/trpc-go/restful"
    32  	"trpc.group/trpc-go/trpc-go/server"
    33  	"trpc.group/trpc-go/trpc-go/transport"
    34  
    35  	_ "trpc.group/trpc-go/trpc-go"
    36  )
    37  
    38  // go test -v -coverprofile=cover.out
    39  // go tool cover -func=cover.out
    40  
    41  var ctx = context.Background()
    42  
    43  type fakeHandler struct {
    44  }
    45  
    46  func (s *fakeHandler) Handle(ctx context.Context, req []byte) (rsp []byte, err error) {
    47  	return req, nil
    48  }
    49  
    50  func TestOptions(t *testing.T) {
    51  
    52  	opts := &server.Options{}
    53  	transportOpts := &transport.ListenServeOptions{}
    54  
    55  	// WithServiceName
    56  	o := server.WithServiceName("trpc.test.helloworld")
    57  	o(opts)
    58  	assert.Equal(t, opts.ServiceName, "trpc.test.helloworld")
    59  
    60  	o = server.WithNamespace("Development")
    61  	o(opts)
    62  	assert.Equal(t, opts.Namespace, "Development")
    63  
    64  	o = server.WithEnvName("formal")
    65  	o(opts)
    66  	assert.Equal(t, opts.EnvName, "formal")
    67  
    68  	o = server.WithSetName("a.b.c")
    69  	o(opts)
    70  	assert.Equal(t, opts.SetName, "a.b.c")
    71  
    72  	// WithDisableRequestTimeout
    73  	assert.Equal(t, opts.DisableRequestTimeout, false) // false by default
    74  	o = server.WithDisableRequestTimeout(true)
    75  	o(opts)
    76  	assert.Equal(t, opts.DisableRequestTimeout, true)
    77  
    78  	// WithAddress
    79  	o = server.WithAddress("127.0.0.1:8080")
    80  	o(opts)
    81  	for _, o := range opts.ServeOptions {
    82  		o(transportOpts)
    83  	}
    84  	assert.Equal(t, transportOpts.Address, "127.0.0.1:8080")
    85  	assert.Equal(t, opts.Address, "127.0.0.1:8080")
    86  
    87  	// WithNetwork
    88  	o = server.WithNetwork("tcp")
    89  	o(opts)
    90  	for _, o := range opts.ServeOptions {
    91  		o(transportOpts)
    92  	}
    93  	assert.Equal(t, transportOpts.Network, "tcp")
    94  
    95  	lis, _ := net.Listen("tcp", "127.0.0.1:8080")
    96  	o = server.WithListener(lis)
    97  	o(opts)
    98  	for _, o := range opts.ServeOptions {
    99  		o(transportOpts)
   100  	}
   101  	assert.Equal(t, transportOpts.Listener, lis)
   102  	if lis != nil {
   103  		lis.Close()
   104  	}
   105  
   106  	o = server.WithTLS("server.crt", "server.key", "ca.pem")
   107  	o(opts)
   108  	for _, o := range opts.ServeOptions {
   109  		o(transportOpts)
   110  	}
   111  	assert.Equal(t, transportOpts.TLSCertFile, "server.crt")
   112  	assert.Equal(t, transportOpts.TLSKeyFile, "server.key")
   113  }
   114  
   115  func TestMoreOptions(t *testing.T) {
   116  	// WithHandler
   117  	h := &fakeHandler{}
   118  	o := server.WithHandler(h)
   119  	opts := &server.Options{}
   120  	transportOpts := &transport.ListenServeOptions{}
   121  	o(opts)
   122  	for _, o := range opts.ServeOptions {
   123  		o(transportOpts)
   124  	}
   125  	assert.Equal(t, transportOpts.Handler, h)
   126  
   127  	// WithTimeout
   128  	o = server.WithTimeout(time.Second)
   129  	o(opts)
   130  	for _, o := range opts.ServeOptions {
   131  		o(transportOpts)
   132  	}
   133  	assert.Equal(t, opts.Timeout, time.Second)
   134  
   135  	// WithTransport
   136  	o = server.WithTransport(transport.DefaultServerTransport)
   137  	o(opts)
   138  	assert.Equal(t, opts.Transport, transport.DefaultServerTransport)
   139  
   140  	// register ServerTransport
   141  	transport.RegisterServerTransport("trpc", transport.DefaultServerTransport)
   142  	// WithProtocol
   143  	o = server.WithProtocol("trpc")
   144  	o(opts)
   145  	for _, o := range opts.ServeOptions {
   146  		o(transportOpts)
   147  	}
   148  	assert.NotEqual(t, opts.Codec, nil)
   149  
   150  	o = server.WithProtocol("fake")
   151  	o(opts)
   152  	for _, o := range opts.ServeOptions {
   153  		o(transportOpts)
   154  	}
   155  
   156  	o = server.WithCurrentSerializationType(codec.SerializationTypeNoop)
   157  	o(opts)
   158  	assert.Equal(t, opts.CurrentSerializationType, codec.SerializationTypeNoop)
   159  
   160  	o = server.WithCurrentCompressType(codec.CompressTypeSnappy)
   161  	o(opts)
   162  	assert.Equal(t, opts.CurrentCompressType, codec.CompressTypeSnappy)
   163  
   164  	// WithFilter
   165  	o = server.WithFilter(filter.NoopServerFilter)
   166  	o(opts)
   167  	assert.Equal(t, len(opts.Filters), 1)
   168  
   169  	// WithFilters
   170  	o = server.WithFilters([]filter.ServerFilter{filter.NoopServerFilter})
   171  	o(opts)
   172  	assert.Equal(t, len(opts.Filters), 2)
   173  
   174  	// WithStreamFilter
   175  	sf1 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error {
   176  		return nil
   177  	}
   178  	o = server.WithStreamFilter(sf1)
   179  	o(opts)
   180  	assert.Equal(t, 1, len(opts.StreamFilters))
   181  
   182  	// WithStreamFilters
   183  	sf2 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error {
   184  		return nil
   185  	}
   186  	o = server.WithStreamFilters(sf1, sf2)
   187  	o(opts)
   188  	assert.Equal(t, 3, len(opts.StreamFilters))
   189  
   190  	// WithRegistry
   191  	o = server.WithRegistry(registry.DefaultRegistry)
   192  	o(opts)
   193  	assert.Equal(t, registry.DefaultRegistry, opts.Registry)
   194  
   195  	// WithServerAsync
   196  	o = server.WithServerAsync(true)
   197  	o(opts)
   198  	for _, o := range opts.ServeOptions {
   199  		o(transportOpts)
   200  	}
   201  	assert.Equal(t, transportOpts.ServerAsync, true)
   202  
   203  	// WithMaxRoutines
   204  	server.WithMaxRoutines(100)(opts)
   205  	// WithWritev
   206  	server.WithWritev(true)(opts)
   207  	for _, o := range opts.ServeOptions {
   208  		o(transportOpts)
   209  	}
   210  	assert.Equal(t, transportOpts.Writev, true)
   211  
   212  	// WithMaxRoutines
   213  	o = server.WithMaxRoutines(100)
   214  	o(opts)
   215  	for _, o := range opts.ServeOptions {
   216  		o(transportOpts)
   217  	}
   218  	assert.Equal(t, transportOpts.Routines, 100)
   219  
   220  	// WithStreamTransport
   221  	o = server.WithStreamTransport(transport.DefaultServerStreamTransport)
   222  	o(opts)
   223  	assert.Equal(t, opts.StreamTransport, transport.DefaultServerStreamTransport)
   224  
   225  	// WithCloseWaitTime
   226  	o = server.WithCloseWaitTime(0 * time.Millisecond)
   227  	o(opts)
   228  	for _, o := range opts.ServeOptions {
   229  		o(transportOpts)
   230  	}
   231  	assert.Equal(t, opts.CloseWaitTime, 0*time.Millisecond)
   232  
   233  	// WithMaxCloseWaitTime
   234  	o = server.WithMaxCloseWaitTime(100 * time.Millisecond)
   235  	o(opts)
   236  	for _, o := range opts.ServeOptions {
   237  		o(transportOpts)
   238  	}
   239  	assert.Equal(t, opts.MaxCloseWaitTime, 100*time.Millisecond)
   240  
   241  	// WithRESTOptions
   242  	o1 := server.WithRESTOptions(restful.WithServiceName("name a"))
   243  	o2 := server.WithRESTOptions(restful.WithServiceName("name b"))
   244  	o1(opts)
   245  	o2(opts)
   246  	restOptions := &restful.Options{}
   247  	for _, o := range opts.RESTOptions {
   248  		o(restOptions)
   249  	}
   250  	assert.Equal(t, 2, len(opts.RESTOptions))
   251  	assert.Equal(t, "name b", restOptions.ServiceName)
   252  
   253  	// WithIdleTimeout
   254  	idleTimeout := time.Second
   255  	o = server.WithIdleTimeout(idleTimeout)
   256  	o(opts)
   257  	for _, o := range opts.ServeOptions {
   258  		o(transportOpts)
   259  	}
   260  	assert.Equal(t, transportOpts.IdleTimeout, idleTimeout)
   261  
   262  	// WithDisableKeepAlives
   263  	disableKeepAlives := true
   264  	o = server.WithDisableKeepAlives(disableKeepAlives)
   265  	o(opts)
   266  	for _, o := range opts.ServeOptions {
   267  		o(transportOpts)
   268  	}
   269  	assert.Equal(t, disableKeepAlives, transportOpts.DisableKeepAlives)
   270  
   271  	// WithMaxWindowSize
   272  	var maxWindowSize uint32 = 100
   273  	o = server.WithMaxWindowSize(maxWindowSize)
   274  	o(opts)
   275  	assert.Equal(t, maxWindowSize, opts.MaxWindowSize)
   276  }
   277  
   278  func TestWithNamedFilter(t *testing.T) {
   279  	var (
   280  		filterNames []string
   281  		filters     filter.ServerChain
   282  
   283  		sf = func(
   284  			ctx context.Context,
   285  			req interface{},
   286  			next filter.ServerHandleFunc,
   287  		) (rsp interface{}, err error) {
   288  			return next(ctx, req)
   289  		}
   290  	)
   291  	for i := 0; i < 10; i++ {
   292  		filterNames = append(filterNames, fmt.Sprintf("filter-%d", i))
   293  		filters = append(filters, sf)
   294  	}
   295  
   296  	var os []server.Option
   297  	for i := range filters {
   298  		os = append(os, server.WithNamedFilter(filterNames[i], filters[i]))
   299  	}
   300  
   301  	options := &server.Options{}
   302  	for _, o := range os {
   303  		o(options)
   304  	}
   305  	require.Equal(t, filterNames, options.FilterNames)
   306  	require.Equal(t, len(filters), len(options.Filters))
   307  	for i := range filters {
   308  		require.Equal(
   309  			t,
   310  			runtime.FuncForPC(reflect.ValueOf(filters[i]).Pointer()).Name(),
   311  			runtime.FuncForPC(reflect.ValueOf(options.Filters[i]).Pointer()).Name(),
   312  		)
   313  	}
   314  }