trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/server_transport_tcp_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 //go:build linux || freebsd || dragonfly || darwin 15 // +build linux freebsd dragonfly darwin 16 17 package tnet_test 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "io" 24 "net" 25 "os" 26 "strconv" 27 "sync/atomic" 28 "testing" 29 "time" 30 31 "github.com/stretchr/testify/assert" 32 "trpc.group/trpc-go/tnet" 33 34 trpc "trpc.group/trpc-go/trpc-go" 35 "trpc.group/trpc-go/trpc-go/codec" 36 "trpc.group/trpc-go/trpc-go/transport" 37 tnettrans "trpc.group/trpc-go/trpc-go/transport/tnet" 38 ) 39 40 var ( 41 port uint64 = 9000 42 helloWorld = []byte("helloworld") 43 ) 44 45 func TestServerTCP_ListenAndServe(t *testing.T) { 46 startServerTest( 47 t, 48 defaultServerHandle, 49 nil, 50 func(addr string) { 51 rsp, err := gonetRequest(context.Background(), transport.WithDialAddress(addr)) 52 assert.Nil(t, err) 53 assert.Equal(t, helloWorld, rsp) 54 }, 55 ) 56 } 57 58 func TestServerTCP_Asyn(t *testing.T) { 59 startServerTest( 60 t, 61 defaultServerHandle, 62 []transport.ListenServeOption{transport.WithServerAsync(true)}, 63 func(addr string) { 64 rsp, err := gonetRequest(context.Background(), transport.WithDialAddress(addr)) 65 assert.Nil(t, err) 66 assert.Equal(t, helloWorld, rsp) 67 }, 68 ) 69 } 70 71 func TestServerTCP_CustomizedFramerCopyFrame(t *testing.T) { 72 startServerTest( 73 t, 74 func(ctx context.Context, req []byte) ([]byte, error) { 75 return req, nil 76 }, 77 []transport.ListenServeOption{ 78 transport.WithServerFramerBuilder(&reuseBufferFramerBuilder{}), 79 transport.WithServerAsync(true), 80 }, 81 func(addr string) { 82 req := helloWorld 83 ctx, _ := codec.EnsureMessage(context.Background()) 84 reqbytes, err := (&emptyClientCodec{}).Encode( 85 codec.Message(ctx), 86 req, 87 ) 88 assert.Nil(t, err) 89 90 cliOpts := []transport.RoundTripOption{ 91 transport.WithDialAddress(addr), 92 transport.WithDialNetwork("tcp"), 93 transport.WithClientFramerBuilder(&reuseBufferFramerBuilder{}), 94 transport.WithDialTimeout(5 * time.Second), 95 } 96 clientTrans := transport.NewClientTransport() 97 rspbytes, err := clientTrans.RoundTrip( 98 ctx, 99 reqbytes, 100 cliOpts..., 101 ) 102 assert.Nil(t, err) 103 104 rsp, err := (&emptyClientCodec{}).Decode( 105 codec.Message(ctx), 106 rspbytes, 107 ) 108 assert.Nil(t, err) 109 assert.Equal(t, helloWorld, rsp) 110 }, 111 ) 112 } 113 114 func TestServerTCP_UserDefineListener(t *testing.T) { 115 serverAddr := getAddr() 116 ln, err := tnet.Listen("tcp", serverAddr) 117 assert.Nil(t, err) 118 startServerTest( 119 t, 120 defaultServerHandle, 121 []transport.ListenServeOption{transport.WithListener(ln)}, 122 func(_ string) { 123 rsp, err := gonetRequest(context.Background(), transport.WithDialAddress(serverAddr)) 124 assert.Nil(t, err) 125 assert.Equal(t, helloWorld, rsp) 126 }, 127 ) 128 } 129 130 func TestServerTCP_ErrorCases(t *testing.T) { 131 s := tnettrans.NewServerTransport() 132 133 // Without framerBuilder 134 serveOpts := getListenServeOption( 135 transport.WithServerFramerBuilder(nil), 136 ) 137 err := s.ListenAndServe(context.Background(), serveOpts...) 138 assert.NotNil(t, err) 139 140 // Unsupported network type 141 serveOpts = getListenServeOption( 142 transport.WithListenNetwork("ip"), 143 ) 144 err = s.ListenAndServe(context.Background(), serveOpts...) 145 assert.NotNil(t, err) 146 } 147 148 func TestServerTCP_HandleErr(t *testing.T) { 149 startServerTest( 150 t, 151 errServerHandle, 152 nil, 153 func(addr string) { 154 _, err := gonetRequest(context.Background(), transport.WithDialAddress(addr)) 155 fmt.Println(err) 156 assert.NotNil(t, err) 157 }, 158 ) 159 } 160 161 func TestServerTCP_IdleTimeout(t *testing.T) { 162 startServerTest( 163 t, 164 defaultServerHandle, 165 []transport.ListenServeOption{transport.WithServerIdleTimeout(time.Second)}, 166 func(addr string) { 167 cliconn, err := tnet.DialTCP("tcp", addr, 0) 168 assert.Nil(t, err) 169 _, err = cliconn.Write([]byte("0")) 170 assert.Nil(t, err) 171 172 // sleep to make sure ListenAndServe run into onRequest() 173 time.Sleep(2 * time.Second) 174 _, err = cliconn.Write([]byte("0")) 175 assert.NotNil(t, err) 176 }, 177 ) 178 179 } 180 181 func TestServerTCP_WriteFail(t *testing.T) { 182 ch := make(chan struct{}, 1) 183 var isHandled bool 184 startServerTest( 185 t, 186 func(ctx context.Context, req []byte) ([]byte, error) { 187 isHandled = true 188 <-ch 189 return nil, nil 190 }, 191 []transport.ListenServeOption{transport.WithServerAsync(true)}, 192 func(addr string) { 193 ctx, _ := codec.EnsureMessage(context.Background()) 194 req, err := trpc.DefaultClientCodec.Encode(codec.Message(ctx), helloWorld) 195 assert.Nil(t, err) 196 197 cliconn, err := tnet.DialTCP("tcp", addr, 0) 198 assert.Nil(t, err) 199 _, err = cliconn.Write(req) 200 assert.Nil(t, err) 201 202 // sleep to make sure server received data 203 time.Sleep(50 * time.Millisecond) 204 cliconn.Close() 205 // notify server write back data, but server will fail, because connection is closed 206 ch <- struct{}{} 207 _, err = cliconn.ReadN(1) 208 assert.NotNil(t, err) 209 // make sure server run into handle 210 assert.True(t, isHandled) 211 }, 212 ) 213 } 214 215 func TestServerTCP_PassedListener(t *testing.T) { 216 serverAddr := getAddr() 217 listener, err := net.Listen("tcp", serverAddr) 218 assert.Nil(t, err) 219 220 transport.SaveListener(listener) 221 fds := transport.GetListenersFds() 222 var fd int 223 for _, f := range fds { 224 if f.Address == serverAddr { 225 fd = int(f.Fd) 226 } 227 } 228 229 os.Setenv(transport.EnvGraceRestart, "1") 230 os.Setenv(transport.EnvGraceFirstFd, strconv.Itoa(fd)) 231 os.Setenv(transport.EnvGraceRestartFdNum, "1") 232 233 defer func() { 234 os.Setenv(transport.EnvGraceRestart, "0") 235 os.Setenv(transport.EnvGraceFirstFd, "0") 236 os.Setenv(transport.EnvGraceRestartFdNum, "0") 237 }() 238 239 startServerTest( 240 t, 241 defaultServerHandle, 242 []transport.ListenServeOption{transport.WithListenAddress(serverAddr)}, 243 func(_ string) { 244 rsp, err := gonetRequest(context.Background(), transport.WithDialAddress(serverAddr)) 245 assert.Nil(t, err) 246 assert.Equal(t, helloWorld, rsp) 247 }, 248 ) 249 } 250 251 func TestServerTCP_ClientWrongReq(t *testing.T) { 252 startServerTest( 253 t, 254 defaultServerHandle, 255 nil, 256 func(addr string) { 257 cliconn, err := tnet.DialTCP("tcp", addr, 0) 258 assert.Nil(t, err) 259 _, err = cliconn.Write([]byte("1234567890123456")) 260 assert.Nil(t, err) 261 262 // sleep to make sure ListenAndServe run into onRequest() 263 time.Sleep(50 * time.Millisecond) 264 err = cliconn.Close() 265 assert.Nil(t, err) 266 }, 267 ) 268 } 269 270 func TestServerTCP_SendAndClose(t *testing.T) { 271 addr := getAddr() 272 s := tnettrans.NewServerTransport() 273 serveOpts := getListenServeOption( 274 transport.WithListenAddress(addr), 275 transport.WithServerAsync(true), 276 ) 277 err := s.ListenAndServe(context.Background(), serveOpts...) 278 assert.Nil(t, err) 279 280 cliconn, err := tnet.DialTCP("tcp", addr, 0) 281 assert.Nil(t, err) 282 cliAddr := cliconn.LocalAddr() 283 284 time.Sleep(50 * time.Millisecond) 285 streamTransport, ok := s.(transport.ServerStreamTransport) 286 assert.True(t, ok) 287 ctx, msg := codec.EnsureMessage(context.Background()) 288 msg.WithRemoteAddr(cliAddr) 289 svrAddr, err := net.ResolveTCPAddr("tcp", addr) 290 assert.Nil(t, err) 291 msg.WithLocalAddr(svrAddr) 292 err = streamTransport.Send(ctx, helloWorld) 293 assert.Nil(t, err) 294 295 b := make([]byte, len(helloWorld)) 296 cliconn.Read(b) 297 assert.Equal(t, b, helloWorld) 298 299 streamTransport.Close(ctx) 300 err = streamTransport.Send(ctx, helloWorld) 301 assert.NotNil(t, err) 302 } 303 304 func TestServerTCP_TLS(t *testing.T) { 305 startServerTest( 306 t, 307 defaultServerHandle, 308 []transport.ListenServeOption{transport.WithServeTLS("../../testdata/server.crt", "../../testdata/server.key", "../../testdata/ca.pem")}, 309 func(addr string) { 310 rsp, err := gonetRequest( 311 context.Background(), 312 transport.WithDialAddress(addr), 313 transport.WithDialTLS("../../testdata/client.crt", "../../testdata/client.key", "../../testdata/ca.pem", "localhost"), 314 ) 315 assert.Nil(t, err) 316 assert.Equal(t, helloWorld, rsp) 317 318 rsp, err = gonetRequest( 319 context.Background(), 320 transport.WithDialAddress(addr), 321 transport.WithDialTLS("../../testdata/client.crt", "../../testdata/client.key", "none", ""), 322 ) 323 assert.Nil(t, err) 324 assert.Equal(t, helloWorld, rsp) 325 }, 326 ) 327 } 328 329 func TestUDP(t *testing.T) { 330 // UDP is not supported, but it will switch to gonet default transport to serve. 331 startServerTest( 332 t, 333 defaultServerHandle, 334 []transport.ListenServeOption{transport.WithListenNetwork("tcp,udp")}, 335 func(addr string) { 336 rsp, err := gonetRequest( 337 context.Background(), 338 transport.WithDialAddress(addr), 339 transport.WithDialNetwork("udp")) 340 assert.Nil(t, err) 341 assert.Equal(t, helloWorld, rsp) 342 343 rsp, err = gonetRequest( 344 context.Background(), 345 transport.WithDialAddress(addr), 346 transport.WithDialNetwork("tcp")) 347 assert.Nil(t, err) 348 assert.Equal(t, helloWorld, rsp) 349 }, 350 ) 351 } 352 353 func TestUnix(t *testing.T) { 354 // Unix socket is not supported, but it will switch to gonet default transport to serve. 355 myAddr := "/tmp/server.sock" 356 os.Remove(myAddr) 357 startServerTest( 358 t, 359 defaultServerHandle, 360 []transport.ListenServeOption{ 361 transport.WithListenNetwork("unix"), 362 transport.WithListenAddress(myAddr), 363 }, 364 func(_ string) { 365 rsp, err := gonetRequest( 366 context.Background(), 367 transport.WithDialAddress(myAddr), 368 transport.WithDialNetwork("unix")) 369 assert.Nil(t, err) 370 assert.Equal(t, helloWorld, rsp) 371 }, 372 ) 373 } 374 375 func getListenServeOption(opts ...transport.ListenServeOption) []transport.ListenServeOption { 376 lsopts := []transport.ListenServeOption{ 377 transport.WithServerFramerBuilder(trpc.DefaultFramerBuilder), 378 transport.WithListenNetwork("tcp"), 379 transport.WithHandler(newUserDefineHandler(defaultServerHandle)), 380 transport.WithServerIdleTimeout(5 * time.Second), 381 } 382 lsopts = append(lsopts, opts...) 383 return lsopts 384 } 385 386 func defaultServerHandle(ctx context.Context, req []byte) (rsp []byte, err error) { 387 msg := codec.Message(ctx) 388 reqdata, err := trpc.DefaultServerCodec.Decode(msg, req) 389 if err != nil { 390 return nil, err 391 } 392 rspdata := make([]byte, len(reqdata)) 393 copy(rspdata, reqdata) 394 rsp, err = trpc.DefaultServerCodec.Encode(msg, rspdata) 395 return rsp, err 396 } 397 398 func errServerHandle(ctx context.Context, req []byte) (rsp []byte, err error) { 399 return nil, errors.New("mock error") 400 } 401 402 type userDefineHandler struct { 403 handleFunc func(context.Context, []byte) ([]byte, error) 404 } 405 406 func newUserDefineHandler(f func(context.Context, []byte) ([]byte, error)) *userDefineHandler { 407 return &userDefineHandler{handleFunc: f} 408 } 409 410 func (uh *userDefineHandler) Handle(ctx context.Context, req []byte) (rsp []byte, err error) { 411 return uh.handleFunc(ctx, req) 412 } 413 414 func startServerTest( 415 t *testing.T, 416 serverHandle func(ctx context.Context, req []byte) ([]byte, error), 417 svrCustomOpts []transport.ListenServeOption, 418 clientHandle func(addr string), 419 ) { 420 addr := getAddr() 421 s := tnettrans.NewServerTransport( 422 tnettrans.WithKeepAlivePeriod(15*time.Second), 423 tnettrans.WithReusePort(true), 424 ) 425 handler := newUserDefineHandler(func(ctx context.Context, req []byte) ([]byte, error) { 426 return serverHandle(ctx, req) 427 }) 428 serveOpts := getListenServeOption( 429 transport.WithListenAddress(addr), 430 transport.WithHandler(handler), 431 ) 432 serveOpts = append(serveOpts, svrCustomOpts...) 433 err := s.ListenAndServe(context.Background(), serveOpts...) 434 assert.Nil(t, err) 435 436 clientHandle(addr) 437 } 438 439 func gonetRequest(ctx context.Context, opts ...transport.RoundTripOption) ([]byte, error) { 440 req := helloWorld 441 ctx, _ = codec.EnsureMessage(ctx) 442 reqbytes, err := trpc.DefaultClientCodec.Encode( 443 codec.Message(ctx), 444 req, 445 ) 446 if err != nil { 447 return nil, err 448 } 449 450 cliOpts := getRoundTripOption(opts...) 451 clientTrans := transport.NewClientTransport() 452 rspbytes, err := clientTrans.RoundTrip( 453 ctx, 454 reqbytes, 455 cliOpts..., 456 ) 457 if err != nil { 458 return nil, err 459 } 460 rsp, err := trpc.DefaultClientCodec.Decode( 461 codec.Message(ctx), 462 rspbytes, 463 ) 464 return rsp, err 465 } 466 467 func getRoundTripOption(opts ...transport.RoundTripOption) []transport.RoundTripOption { 468 rtopts := []transport.RoundTripOption{ 469 transport.WithDialNetwork("tcp"), 470 transport.WithClientFramerBuilder(trpc.DefaultFramerBuilder), 471 transport.WithDialTimeout(5 * time.Second), 472 } 473 rtopts = append(rtopts, opts...) 474 return rtopts 475 } 476 477 func getAddr() string { 478 atomic.AddUint64(&port, 1) 479 return "127.0.0.1:" + fmt.Sprint(port) 480 } 481 482 type reuseBufferFramerBuilder struct{} 483 484 func (*reuseBufferFramerBuilder) New(r io.Reader) codec.Framer { 485 return &reuseBufferFramer{r: r, reuseBuffer: make([]byte, len(helloWorld))} 486 } 487 488 type reuseBufferFramer struct { 489 r io.Reader 490 reuseBuffer []byte 491 } 492 493 func (f *reuseBufferFramer) ReadFrame() ([]byte, error) { 494 _, err := io.ReadFull(f.r, f.reuseBuffer) 495 if err != nil { 496 return nil, fmt.Errorf("io.ReadFull err: %w", err) 497 } 498 return f.reuseBuffer, nil 499 } 500 501 type emptyServerCodec struct{} 502 503 func (s *emptyServerCodec) Decode(msg codec.Msg, reqBuf []byte) ([]byte, error) { 504 return reqBuf, nil 505 } 506 507 func (s *emptyServerCodec) Encode(msg codec.Msg, rspBody []byte) ([]byte, error) { 508 return rspBody, nil 509 } 510 511 type emptyClientCodec struct{} 512 513 func (s *emptyClientCodec) Decode(msg codec.Msg, reqBuf []byte) ([]byte, error) { 514 return reqBuf, nil 515 } 516 517 func (s *emptyClientCodec) Encode(msg codec.Msg, rspBody []byte) ([]byte, error) { 518 return rspBody, nil 519 }