trpc.group/trpc-go/trpc-go@v1.0.3/server/service_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package server_test 15 16 import ( 17 "context" 18 "errors" 19 "math/rand" 20 "net" 21 "os" 22 "testing" 23 "time" 24 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 28 "trpc.group/trpc-go/trpc-go/client" 29 "trpc.group/trpc-go/trpc-go/codec" 30 "trpc.group/trpc-go/trpc-go/errs" 31 "trpc.group/trpc-go/trpc-go/filter" 32 "trpc.group/trpc-go/trpc-go/log" 33 "trpc.group/trpc-go/trpc-go/naming/registry" 34 "trpc.group/trpc-go/trpc-go/restful" 35 "trpc.group/trpc-go/trpc-go/server" 36 pb "trpc.group/trpc-go/trpc-go/testdata/trpc/helloworld" 37 "trpc.group/trpc-go/trpc-go/transport" 38 ) 39 40 func init() { 41 rand.Seed(time.Now().Unix()) 42 } 43 44 // go test -v 45 type fakeTransport struct { 46 } 47 48 func (s *fakeTransport) ListenAndServe(ctx context.Context, opts ...transport.ListenServeOption) error { 49 lsopts := &transport.ListenServeOptions{} 50 for _, opt := range opts { 51 opt(lsopts) 52 } 53 54 go func() { 55 lsopts.Handler.Handle(ctx, []byte("normal-request")) 56 lsopts.Handler.Handle(ctx, []byte("stream")) 57 lsopts.Handler.Handle(ctx, []byte("no-rpc-name")) 58 lsopts.Handler.Handle(ctx, []byte("decode-error")) 59 lsopts.Handler.Handle(ctx, []byte("encode-error")) 60 lsopts.Handler.Handle(ctx, []byte("handle-timeout")) 61 lsopts.Handler.Handle(ctx, []byte("no-response")) 62 lsopts.Handler.Handle(ctx, []byte("business-fail")) 63 lsopts.Handler.Handle(ctx, []byte("handle-panic")) 64 lsopts.Handler.Handle(ctx, []byte("compress-error")) 65 lsopts.Handler.Handle(ctx, []byte("decompress-error")) 66 lsopts.Handler.Handle(ctx, []byte("unmarshal-error")) 67 lsopts.Handler.Handle(ctx, []byte("marshal-error")) 68 ctx := context.Background() 69 ctx, msg := codec.WithNewMessage(ctx) 70 msg.WithServerRspErr(errors.New("connection is tryClose ")) 71 lsopts.Handler.Handle(ctx, nil) 72 73 }() 74 75 return nil 76 } 77 78 type fakeCodec struct { 79 } 80 81 func (c *fakeCodec) Decode(msg codec.Msg, reqBuf []byte) (reqBody []byte, err error) { 82 req := string(reqBuf) 83 84 if req == "stream" { 85 msg.WithServerRPCName("/trpc.test.helloworld.Greeter/SayHi") 86 return reqBuf, nil 87 } 88 if req != "no-rpc-name" { 89 msg.WithServerRPCName("/trpc.test.helloworld.Greeter/SayHello") 90 } 91 if req == "decode-error" { 92 return nil, errors.New("server decode request fail") 93 } 94 msg.WithRequestTimeout(time.Second) 95 msg.WithSerializationType(codec.SerializationTypeNoop) 96 log.Infof("fakeCodec ==> req[%v]", req) 97 return reqBuf, nil 98 } 99 100 func (c *fakeCodec) Encode(msg codec.Msg, rspBody []byte) (rspBuf []byte, err error) { 101 rsp := string(rspBody) 102 if rsp == "encode-error" { 103 return nil, errors.New("server encode response fail") 104 } 105 return rspBody, nil 106 } 107 108 func (c *fakeCodec) Compress(in []byte) (out []byte, err error) { 109 rsp := string(in) 110 if rsp == "compress-error" { 111 return nil, errors.New("server compress fail") 112 } 113 return in, nil 114 } 115 116 func (c *fakeCodec) Decompress(in []byte) (out []byte, err error) { 117 req := string(in) 118 if req == "decompress-error" { 119 return nil, errors.New("server decompress fail") 120 } 121 return in, nil 122 } 123 124 func (c *fakeCodec) Unmarshal(reqBuf []byte, reqBody interface{}) error { 125 req := string(reqBuf) 126 if req == "unmarshal-error" { 127 return errors.New("server unmarshal fail") 128 } 129 return codec.Unmarshal(codec.SerializationTypeNoop, reqBuf, reqBody) 130 } 131 132 func (c *fakeCodec) Marshal(rspBody interface{}) (rspBuf []byte, err error) { 133 if rsp, ok := rspBody.(*codec.Body); ok { 134 if string(rsp.Data) == "marshal-error" { 135 return nil, errors.New("server marshal fail") 136 } 137 } 138 return codec.Marshal(codec.SerializationTypeNoop, rspBody) 139 } 140 141 type fakeRegistry struct { 142 } 143 144 func (r *fakeRegistry) Register(service string, opt ...registry.Option) error { 145 return nil 146 } 147 func (r *fakeRegistry) Deregister(service string) error { 148 return nil 149 } 150 151 func TestService(t *testing.T) { 152 codec.Register("fake", &fakeCodec{}, nil) 153 // register the fake codec 154 codec.RegisterCompressor(930, &fakeCodec{}) 155 codec.RegisterSerializer(1930, &fakeCodec{}) 156 157 // 1.codec not set,transport will cause error. 158 service := server.New(server.WithServiceName("trpc.test.helloworld.Greeter"), 159 server.WithTransport(&fakeTransport{}), 160 server.WithRegistry(®istry.NoopRegistry{})) 161 162 impl := &GreeterServerImpl{} 163 err := service.Register(&GreeterServerServiceDesc, impl) 164 assert.Nil(t, err) 165 166 go func() { 167 _ = service.Serve() 168 }() 169 // closing service will not return error even if registry fails. 170 err = service.Close(nil) 171 assert.Nil(t, err) 172 173 // 2. valid service registration 174 service = server.New(server.WithProtocol("fake"), 175 server.WithServiceName("trpc.test.helloworld.Greeter"), 176 server.WithTransport(&fakeTransport{}), 177 server.WithRegistry(&fakeRegistry{}), 178 server.WithCurrentSerializationType(1930), 179 server.WithCurrentCompressType(930), 180 server.WithCloseWaitTime(100*time.Millisecond), 181 server.WithMaxCloseWaitTime(200*time.Millisecond)) 182 err = service.Register(&GreeterServerServiceDesc, impl) 183 assert.Nil(t, err) 184 185 // RESTful router should exist 186 assert.NotNil(t, restful.GetRouter("trpc.test.helloworld.Greeter")) 187 188 go func() { 189 _ = service.Serve() 190 }() 191 time.Sleep(time.Second * 2) 192 err = service.Close(nil) 193 assert.Nil(t, err) 194 } 195 196 // TestServiceFail tests failures of request handling. 197 func TestServiceFail(t *testing.T) { 198 199 codec.Register("fake", &fakeCodec{}, nil) 200 service := server.New(server.WithProtocol("fake"), 201 server.WithServiceName("trpc.test.helloworld.Greeter"), 202 server.WithTransport(&fakeTransport{}), 203 server.WithRegistry(&fakeRegistry{}), 204 ) 205 206 impl := &GreeterServerImpl{} 207 err := service.Register(&GreeterServerServiceDescFail, impl) 208 assert.Nil(t, err) 209 go func() { 210 service.Serve() 211 }() 212 213 time.Sleep(time.Second * 2) 214 } 215 216 // TestServiceMethodNameUniqueness tests method name uniqueness 217 func TestServiceMethodNameUniqueness(t *testing.T) { 218 codec.Register("fake", &fakeCodec{}, nil) 219 service := server.New(server.WithProtocol("fake"), 220 server.WithServiceName("trpc.test.helloworld.Greeter"), 221 server.WithTransport(&fakeTransport{}), 222 server.WithRegistry(&fakeRegistry{}), 223 ) 224 225 impl := &GreeterServerImpl{} 226 err := service.Register(&GreeterServerServiceDescFail, impl) 227 assert.Nil(t, err) 228 229 err = service.Register(&GreeterServerServiceDescFail, impl) 230 assert.NotNil(t, err) 231 } 232 233 func TestServiceTimeout(t *testing.T) { 234 require.Nil(t, os.Setenv(transport.EnvGraceRestart, "")) 235 t.Run("server timeout", func(t *testing.T) { 236 addr, stop := startService(t, &GreeterServerImpl{}, 237 server.WithTimeout(time.Second), 238 server.WithFilter( 239 func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) { 240 return nil, errs.NewFrameError(errs.RetServerTimeout, "") 241 })) 242 defer stop() 243 244 c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr)) 245 _, err := c.SayHello(context.Background(), &pb.HelloRequest{}) 246 require.NotNil(t, err) 247 e, ok := err.(*errs.Error) 248 require.True(t, ok) 249 require.EqualValues(t, int32(errs.RetServerTimeout), e.Code) 250 }) 251 t.Run("client full link timeout is converted to server timeout", 252 func(t *testing.T) { 253 addr, stop := startService(t, 254 &Greeter{ 255 sayHello: func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) { 256 return nil, errs.NewFrameError(errs.RetClientFullLinkTimeout, "") 257 }}, 258 server.WithTimeout(time.Second)) 259 defer stop() 260 261 c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr)) 262 _, err := c.SayHello(ctx, &pb.HelloRequest{}) 263 require.NotNil(t, err) 264 e, ok := err.(*errs.Error) 265 require.True(t, ok) 266 require.Equal(t, errs.ErrorTypeCalleeFramework, e.Type) 267 require.EqualValues(t, int32(errs.RetServerTimeout), e.Code) 268 }) 269 t.Run("client full link timeout is converted to server full link timeout, and then dropped", 270 func(t *testing.T) { 271 addr, stop := startService(t, 272 &Greeter{ 273 sayHello: func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) { 274 return nil, errs.NewFrameError(errs.RetClientFullLinkTimeout, "") 275 }}, 276 server.WithTimeout(time.Second*2)) 277 defer stop() 278 279 c := pb.NewGreeterClientProxy(client.WithTarget("ip://" + addr)) 280 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 281 defer cancel() 282 _, err := c.SayHello(ctx, &pb.HelloRequest{}) 283 require.NotNil(t, err) 284 e, ok := err.(*errs.Error) 285 require.True(t, ok) 286 require.Equal(t, errs.ErrorTypeFramework, e.Type) 287 require.EqualValues(t, int32(errs.RetClientFullLinkTimeout), e.Code, 288 "server full link timeout is dropped, and client should receive a client timeout error") 289 }) 290 } 291 292 func TestServiceUDP(t *testing.T) { 293 addr := "127.0.0.1:10000" 294 s := server.New([]server.Option{ 295 server.WithNetwork("udp"), 296 server.WithProtocol("trpc"), 297 server.WithAddress(addr), 298 server.WithCurrentSerializationType(codec.SerializationTypeNoop), 299 }...) 300 require.Nil(t, s.Register(&GreeterServerServiceDesc, &GreeterServerImpl{})) 301 go s.Serve() 302 time.Sleep(time.Millisecond * 200) 303 304 c := pb.NewGreeterClientProxy(client.WithTarget("ip://"+addr), client.WithNetwork("udp")) 305 _, err := c.SayHello(context.Background(), &pb.HelloRequest{}) 306 require.Nil(t, err) 307 } 308 309 func TestCloseWaitTime(t *testing.T) { 310 startService := func(opts ...server.Option) (chan struct{}, func()) { 311 received, done := make(chan struct{}), make(chan struct{}) 312 addr, stop := startService(t, &Greeter{}, append([]server.Option{server.WithFilter( 313 func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) { 314 received <- struct{}{} 315 <-done 316 return nil, errors.New("must fail") 317 })}, opts...)...) 318 go func() { 319 _, _ = pb.NewGreeterClientProxy(client.WithTarget("ip://"+addr)). 320 SayHello(context.Background(), &pb.HelloRequest{}) 321 }() 322 <-received 323 return done, stop 324 } 325 t.Run("active requests feature is not enabled on missing MaxCloseWaitTime", func(t *testing.T) { 326 done, stop := startService() 327 defer close(done) 328 start := time.Now() 329 stop() 330 require.Less(t, time.Since(start), time.Millisecond*100) 331 }) 332 t.Run("total wait time should not significantly greater than MaxCloseWaitTime", func(t *testing.T) { 333 const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second 334 done, stop := startService( 335 server.WithMaxCloseWaitTime(maxCloseWaitTime), 336 server.WithCloseWaitTime(closeWaitTime)) 337 defer close(done) 338 start := time.Now() 339 stop() 340 require.WithinRange(t, time.Now(), 341 // 300ms comes from the internal implementation when close service 342 start.Add(maxCloseWaitTime).Add(time.Millisecond*300), 343 start.Add(maxCloseWaitTime).Add(time.Millisecond*500)) 344 }) 345 t.Run("total wait time is at least CloseWaitTime", func(t *testing.T) { 346 const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second 347 done, stop := startService( 348 server.WithMaxCloseWaitTime(maxCloseWaitTime), 349 server.WithCloseWaitTime(closeWaitTime)) 350 start := time.Now() 351 time.AfterFunc(closeWaitTime/2, func() { close(done) }) 352 stop() 353 require.WithinRange(t, time.Now(), start.Add(closeWaitTime), start.Add(closeWaitTime+time.Millisecond*100)) 354 }) 355 t.Run("no active request before MaxCloseWaitTime", func(t *testing.T) { 356 const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second 357 done, stop := startService( 358 server.WithMaxCloseWaitTime(maxCloseWaitTime), 359 server.WithCloseWaitTime(closeWaitTime)) 360 start := time.Now() 361 time.AfterFunc((closeWaitTime+maxCloseWaitTime)/2, func() { close(done) }) 362 stop() 363 require.WithinRange(t, time.Now(), start.Add(closeWaitTime), start.Add(maxCloseWaitTime)) 364 }) 365 t.Run("no active request before service timeout", func(t *testing.T) { 366 const closeWaitTime, maxCloseWaitTime, timeout = time.Millisecond * 500, time.Second, time.Second 367 done, stop := startService( 368 server.WithMaxCloseWaitTime(maxCloseWaitTime), 369 server.WithCloseWaitTime(closeWaitTime), 370 server.WithTimeout(timeout)) 371 start := time.Now() 372 time.AfterFunc(maxCloseWaitTime+time.Millisecond*100, func() { close(done) }) 373 stop() 374 require.WithinRange(t, time.Now(), start.Add(maxCloseWaitTime+time.Millisecond*100), start.Add(maxCloseWaitTime+timeout)) 375 }) 376 } 377 378 func startService(t *testing.T, gs GreeterServer, opts ...server.Option) (addr string, stop func()) { 379 l, err := net.Listen("tcp", "0.0.0.0:0") 380 require.Nil(t, err) 381 382 s := server.New(append(append( 383 []server.Option{ 384 server.WithNetwork("tcp"), 385 server.WithProtocol("trpc"), 386 }, opts...), 387 server.WithListener(l), 388 )...) 389 require.Nil(t, s.Register(&GreeterServerServiceDesc, gs)) 390 391 errCh := make(chan error) 392 go func() { errCh <- s.Serve() }() 393 select { 394 case err := <-errCh: 395 require.FailNow(t, "serve failed", err) 396 case <-time.After(time.Millisecond * 200): 397 } 398 return l.Addr().String(), func() { s.Close(nil) } 399 } 400 401 func TestGetStreamFilter(t *testing.T) { 402 expectedErr := errors.New("expected error") 403 testFilter := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error { 404 return expectedErr 405 } 406 server.RegisterStreamFilter("testFilter", testFilter) 407 filter := server.GetStreamFilter("testFilter") 408 err := filter(nil, &server.StreamServerInfo{}, nil) 409 assert.Equal(t, expectedErr, err) 410 } 411 412 type Greeter struct { 413 sayHello func(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) 414 } 415 416 func (g *Greeter) SayHello(ctx context.Context, req *codec.Body) (rsp *codec.Body, err error) { 417 return g.sayHello(ctx, req) 418 } 419 420 func (*Greeter) SayHi(gs Greeter_SayHiServer) error { 421 return nil 422 } 423 424 func TestStreamFilterChainFilter(t *testing.T) { 425 ch := make(chan int, 10) 426 sf1 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error { 427 ch <- 1 428 err := handler(ss) 429 ch <- 5 430 return err 431 } 432 sf2 := func(ss server.Stream, info *server.StreamServerInfo, handler server.StreamHandler) error { 433 ch <- 2 434 err := handler(ss) 435 ch <- 4 436 return err 437 } 438 option := server.WithStreamFilters(sf1, sf2) 439 options := server.Options{} 440 option(&options) 441 _ = options.StreamFilters.Filter(nil, nil, func(stream server.Stream) error { 442 ch <- 3 443 return nil 444 }) 445 assert.Equal(t, 1, <-ch) 446 assert.Equal(t, 2, <-ch) 447 assert.Equal(t, 3, <-ch) 448 assert.Equal(t, 4, <-ch) 449 assert.Equal(t, 5, <-ch) 450 }