trpc.group/trpc-go/trpc-go@v1.0.3/server/options_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package server_test 15 16 import ( 17 "context" 18 "fmt" 19 "net" 20 "reflect" 21 "runtime" 22 "testing" 23 "time" 24 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 28 "trpc.group/trpc-go/trpc-go/codec" 29 "trpc.group/trpc-go/trpc-go/filter" 30 "trpc.group/trpc-go/trpc-go/naming/registry" 31 "trpc.group/trpc-go/trpc-go/restful" 32 "trpc.group/trpc-go/trpc-go/server" 33 "trpc.group/trpc-go/trpc-go/transport" 34 35 _ "trpc.group/trpc-go/trpc-go" 36 ) 37 38 // go test -v -coverprofile=cover.out 39 // go tool cover -func=cover.out 40 41 var ctx = context.Background() 42 43 type fakeHandler struct { 44 } 45 46 func (s *fakeHandler) Handle(ctx context.Context, req []byte) (rsp []byte, err error) { 47 return req, nil 48 } 49 50 func TestOptions(t *testing.T) { 51 52 opts := &server.Options{} 53 transportOpts := &transport.ListenServeOptions{} 54 55 // WithServiceName 56 o := server.WithServiceName("trpc.test.helloworld") 57 o(opts) 58 assert.Equal(t, opts.ServiceName, "trpc.test.helloworld") 59 60 o = server.WithNamespace("Development") 61 o(opts) 62 assert.Equal(t, opts.Namespace, "Development") 63 64 o = server.WithEnvName("formal") 65 o(opts) 66 assert.Equal(t, opts.EnvName, "formal") 67 68 o = server.WithSetName("a.b.c") 69 o(opts) 70 assert.Equal(t, opts.SetName, "a.b.c") 71 72 // WithDisableRequestTimeout 73 assert.Equal(t, opts.DisableRequestTimeout, false) // false by default 74 o = server.WithDisableRequestTimeout(true) 75 o(opts) 76 assert.Equal(t, opts.DisableRequestTimeout, true) 77 78 // WithAddress 79 o = server.WithAddress("127.0.0.1:8080") 80 o(opts) 81 for _, o := range opts.ServeOptions { 82 o(transportOpts) 83 } 84 assert.Equal(t, transportOpts.Address, "127.0.0.1:8080") 85 assert.Equal(t, opts.Address, "127.0.0.1:8080") 86 87 // WithNetwork 88 o = server.WithNetwork("tcp") 89 o(opts) 90 for _, o := range opts.ServeOptions { 91 o(transportOpts) 92 } 93 assert.Equal(t, transportOpts.Network, "tcp") 94 95 lis, _ := net.Listen("tcp", "127.0.0.1:8080") 96 o = server.WithListener(lis) 97 o(opts) 98 for _, o := range opts.ServeOptions { 99 o(transportOpts) 100 } 101 assert.Equal(t, transportOpts.Listener, lis) 102 if lis != nil { 103 lis.Close() 104 } 105 106 o = server.WithTLS("server.crt", "server.key", "ca.pem") 107 o(opts) 108 for _, o := range opts.ServeOptions { 109 o(transportOpts) 110 } 111 assert.Equal(t, transportOpts.TLSCertFile, "server.crt") 112 assert.Equal(t, transportOpts.TLSKeyFile, "server.key") 113 } 114 115 func TestMoreOptions(t *testing.T) { 116 // WithHandler 117 h := &fakeHandler{} 118 o := server.WithHandler(h) 119 opts := &server.Options{} 120 transportOpts := &transport.ListenServeOptions{} 121 o(opts) 122 for _, o := range opts.ServeOptions { 123 o(transportOpts) 124 } 125 assert.Equal(t, transportOpts.Handler, h) 126 127 // WithTimeout 128 o = server.WithTimeout(time.Second) 129 o(opts) 130 for _, o := range opts.ServeOptions { 131 o(transportOpts) 132 } 133 assert.Equal(t, opts.Timeout, time.Second) 134 135 // WithTransport 136 o = server.WithTransport(transport.DefaultServerTransport) 137 o(opts) 138 assert.Equal(t, opts.Transport, transport.DefaultServerTransport) 139 140 // register ServerTransport 141 transport.RegisterServerTransport("trpc", transport.DefaultServerTransport) 142 // WithProtocol 143 o = server.WithProtocol("trpc") 144 o(opts) 145 for _, o := range opts.ServeOptions { 146 o(transportOpts) 147 } 148 assert.NotEqual(t, opts.Codec, nil) 149 150 o = server.WithProtocol("fake") 151 o(opts) 152 for _, o := range opts.ServeOptions { 153 o(transportOpts) 154 } 155 156 o = server.WithCurrentSerializationType(codec.SerializationTypeNoop) 157 o(opts) 158 assert.Equal(t, opts.CurrentSerializationType, codec.SerializationTypeNoop) 159 160 o = server.WithCurrentCompressType(codec.CompressTypeSnappy) 161 o(opts) 162 assert.Equal(t, opts.CurrentCompressType, codec.CompressTypeSnappy) 163 164 // WithFilter 165 o = server.WithFilter(filter.NoopServerFilter) 166 o(opts) 167 assert.Equal(t, len(opts.Filters), 1) 168 169 // WithFilters 170 o = server.WithFilters([]filter.ServerFilter{filter.NoopServerFilter}) 171 o(opts) 172 assert.Equal(t, len(opts.Filters), 2) 173 174 // WithStreamFilter 175 sf1 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error { 176 return nil 177 } 178 o = server.WithStreamFilter(sf1) 179 o(opts) 180 assert.Equal(t, 1, len(opts.StreamFilters)) 181 182 // WithStreamFilters 183 sf2 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error { 184 return nil 185 } 186 o = server.WithStreamFilters(sf1, sf2) 187 o(opts) 188 assert.Equal(t, 3, len(opts.StreamFilters)) 189 190 // WithRegistry 191 o = server.WithRegistry(registry.DefaultRegistry) 192 o(opts) 193 assert.Equal(t, registry.DefaultRegistry, opts.Registry) 194 195 // WithServerAsync 196 o = server.WithServerAsync(true) 197 o(opts) 198 for _, o := range opts.ServeOptions { 199 o(transportOpts) 200 } 201 assert.Equal(t, transportOpts.ServerAsync, true) 202 203 // WithMaxRoutines 204 server.WithMaxRoutines(100)(opts) 205 // WithWritev 206 server.WithWritev(true)(opts) 207 for _, o := range opts.ServeOptions { 208 o(transportOpts) 209 } 210 assert.Equal(t, transportOpts.Writev, true) 211 212 // WithMaxRoutines 213 o = server.WithMaxRoutines(100) 214 o(opts) 215 for _, o := range opts.ServeOptions { 216 o(transportOpts) 217 } 218 assert.Equal(t, transportOpts.Routines, 100) 219 220 // WithStreamTransport 221 o = server.WithStreamTransport(transport.DefaultServerStreamTransport) 222 o(opts) 223 assert.Equal(t, opts.StreamTransport, transport.DefaultServerStreamTransport) 224 225 // WithCloseWaitTime 226 o = server.WithCloseWaitTime(0 * time.Millisecond) 227 o(opts) 228 for _, o := range opts.ServeOptions { 229 o(transportOpts) 230 } 231 assert.Equal(t, opts.CloseWaitTime, 0*time.Millisecond) 232 233 // WithMaxCloseWaitTime 234 o = server.WithMaxCloseWaitTime(100 * time.Millisecond) 235 o(opts) 236 for _, o := range opts.ServeOptions { 237 o(transportOpts) 238 } 239 assert.Equal(t, opts.MaxCloseWaitTime, 100*time.Millisecond) 240 241 // WithRESTOptions 242 o1 := server.WithRESTOptions(restful.WithServiceName("name a")) 243 o2 := server.WithRESTOptions(restful.WithServiceName("name b")) 244 o1(opts) 245 o2(opts) 246 restOptions := &restful.Options{} 247 for _, o := range opts.RESTOptions { 248 o(restOptions) 249 } 250 assert.Equal(t, 2, len(opts.RESTOptions)) 251 assert.Equal(t, "name b", restOptions.ServiceName) 252 253 // WithIdleTimeout 254 idleTimeout := time.Second 255 o = server.WithIdleTimeout(idleTimeout) 256 o(opts) 257 for _, o := range opts.ServeOptions { 258 o(transportOpts) 259 } 260 assert.Equal(t, transportOpts.IdleTimeout, idleTimeout) 261 262 // WithDisableKeepAlives 263 disableKeepAlives := true 264 o = server.WithDisableKeepAlives(disableKeepAlives) 265 o(opts) 266 for _, o := range opts.ServeOptions { 267 o(transportOpts) 268 } 269 assert.Equal(t, disableKeepAlives, transportOpts.DisableKeepAlives) 270 271 // WithMaxWindowSize 272 var maxWindowSize uint32 = 100 273 o = server.WithMaxWindowSize(maxWindowSize) 274 o(opts) 275 assert.Equal(t, maxWindowSize, opts.MaxWindowSize) 276 } 277 278 func TestWithNamedFilter(t *testing.T) { 279 var ( 280 filterNames []string 281 filters filter.ServerChain 282 283 sf = func( 284 ctx context.Context, 285 req interface{}, 286 next filter.ServerHandleFunc, 287 ) (rsp interface{}, err error) { 288 return next(ctx, req) 289 } 290 ) 291 for i := 0; i < 10; i++ { 292 filterNames = append(filterNames, fmt.Sprintf("filter-%d", i)) 293 filters = append(filters, sf) 294 } 295 296 var os []server.Option 297 for i := range filters { 298 os = append(os, server.WithNamedFilter(filterNames[i], filters[i])) 299 } 300 301 options := &server.Options{} 302 for _, o := range os { 303 o(options) 304 } 305 require.Equal(t, filterNames, options.FilterNames) 306 require.Equal(t, len(filters), len(options.Filters)) 307 for i := range filters { 308 require.Equal( 309 t, 310 runtime.FuncForPC(reflect.ValueOf(filters[i]).Pointer()).Name(), 311 runtime.FuncForPC(reflect.ValueOf(options.Filters[i]).Pointer()).Name(), 312 ) 313 } 314 }