trpc.group/trpc-go/trpc-go@v1.0.3/transport/server_transport.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package transport 15 16 import ( 17 "context" 18 "crypto/tls" 19 "errors" 20 "fmt" 21 "net" 22 "os" 23 "runtime" 24 "strconv" 25 "strings" 26 "sync" 27 "syscall" 28 "time" 29 30 "github.com/panjf2000/ants/v2" 31 "trpc.group/trpc-go/trpc-go/internal/reuseport" 32 33 itls "trpc.group/trpc-go/trpc-go/internal/tls" 34 "trpc.group/trpc-go/trpc-go/log" 35 ) 36 37 const transportName = "go-net" 38 39 func init() { 40 RegisterServerTransport(transportName, DefaultServerStreamTransport) 41 } 42 43 const ( 44 // EnvGraceRestart is the flag of graceful restart. 45 EnvGraceRestart = "TRPC_IS_GRACEFUL" 46 47 // EnvGraceFirstFd is the fd of graceful first listener. 48 EnvGraceFirstFd = "TRPC_GRACEFUL_1ST_LISTENFD" 49 50 // EnvGraceRestartFdNum is the number of fd for graceful restart. 51 EnvGraceRestartFdNum = "TRPC_GRACEFUL_LISTENFD_NUM" 52 53 // EnvGraceRestartPPID is the PPID of graceful restart. 54 EnvGraceRestartPPID = "TRPC_GRACEFUL_PPID" 55 ) 56 57 var ( 58 errUnSupportedListenerType = errors.New("not supported listener type") 59 errUnSupportedNetworkType = errors.New("not supported network type") 60 errFileIsNotSocket = errors.New("file is not a socket") 61 ) 62 63 // DefaultServerTransport is the default implementation of ServerStreamTransport. 64 var DefaultServerTransport = NewServerTransport(WithReusePort(true)) 65 66 // NewServerTransport creates a new ServerTransport. 
67 func NewServerTransport(opt ...ServerTransportOption) ServerTransport { 68 r := newServerTransport(opt...) 69 return &r 70 } 71 72 // newServerTransport creates a new serverTransport. 73 func newServerTransport(opt ...ServerTransportOption) serverTransport { 74 // this is the default option. 75 opts := defaultServerTransportOptions() 76 for _, o := range opt { 77 o(opts) 78 } 79 addrToConn := make(map[string]*tcpconn) 80 return serverTransport{addrToConn: addrToConn, m: &sync.RWMutex{}, opts: opts} 81 } 82 83 // serverTransport is the implementation details of server transport, may be tcp or udp. 84 type serverTransport struct { 85 addrToConn map[string]*tcpconn 86 m *sync.RWMutex 87 opts *ServerTransportOptions 88 } 89 90 // ListenAndServe starts Listening, returns an error on failure. 91 func (s *serverTransport) ListenAndServe(ctx context.Context, opts ...ListenServeOption) error { 92 lsopts := &ListenServeOptions{} 93 for _, opt := range opts { 94 opt(lsopts) 95 } 96 97 if lsopts.Listener != nil { 98 return s.listenAndServeStream(ctx, lsopts) 99 } 100 // Support simultaneous listening TCP and UDP. 101 networks := strings.Split(lsopts.Network, ",") 102 for _, network := range networks { 103 lsopts.Network = network 104 switch lsopts.Network { 105 case "tcp", "tcp4", "tcp6", "unix": 106 if err := s.listenAndServeStream(ctx, lsopts); err != nil { 107 return err 108 } 109 case "udp", "udp4", "udp6": 110 if err := s.listenAndServePacket(ctx, lsopts); err != nil { 111 return err 112 } 113 default: 114 return fmt.Errorf("server transport: not support network type %s", lsopts.Network) 115 } 116 } 117 return nil 118 } 119 120 // ---------------------------------stream server-----------------------------------------// 121 122 var ( 123 // listenersMap records the listeners in use in the current process. 124 listenersMap = &sync.Map{} 125 // inheritedListenersMap record the listeners inherited from the parent process. 
126 // A key(host:port) may have multiple listener fds. 127 inheritedListenersMap = &sync.Map{} 128 // once controls fds passed from parent process to construct listeners. 129 once sync.Once 130 ) 131 132 // GetListenersFds gets listener fds. 133 func GetListenersFds() []*ListenFd { 134 listenersFds := []*ListenFd{} 135 listenersMap.Range(func(key, _ interface{}) bool { 136 var ( 137 fd *ListenFd 138 err error 139 ) 140 141 switch k := key.(type) { 142 case net.Listener: 143 fd, err = getListenerFd(k) 144 case net.PacketConn: 145 fd, err = getPacketConnFd(k) 146 default: 147 log.Errorf("listener type passing not supported, type: %T", key) 148 err = fmt.Errorf("not supported listener type: %T", key) 149 } 150 if err != nil { 151 log.Errorf("cannot get the listener fd, err: %v", err) 152 return true 153 } 154 listenersFds = append(listenersFds, fd) 155 return true 156 }) 157 return listenersFds 158 } 159 160 // SaveListener saves the listener. 161 func SaveListener(listener interface{}) error { 162 switch listener.(type) { 163 case net.Listener, net.PacketConn: 164 listenersMap.Store(listener, struct{}{}) 165 default: 166 return fmt.Errorf("not supported listener type: %T", listener) 167 } 168 return nil 169 } 170 171 // getTCPListener gets the TCP/Unix listener. 172 func (s *serverTransport) getTCPListener(opts *ListenServeOptions) (listener net.Listener, err error) { 173 listener = opts.Listener 174 175 if listener != nil { 176 return listener, nil 177 } 178 179 v, _ := os.LookupEnv(EnvGraceRestart) 180 ok, _ := strconv.ParseBool(v) 181 if ok { 182 // find the passed listener 183 pln, err := getPassedListener(opts.Network, opts.Address) 184 if err != nil { 185 return nil, err 186 } 187 188 listener, ok := pln.(net.Listener) 189 if !ok { 190 return nil, errors.New("invalid net.Listener") 191 } 192 return listener, nil 193 } 194 195 // Reuse port. To speed up IO, the kernel dispatches IO ReadReady events to threads. 
196 if s.opts.ReusePort && opts.Network != "unix" { 197 listener, err = reuseport.Listen(opts.Network, opts.Address) 198 if err != nil { 199 return nil, fmt.Errorf("%s reuseport error:%v", opts.Network, err) 200 } 201 } else { 202 listener, err = net.Listen(opts.Network, opts.Address) 203 if err != nil { 204 return nil, err 205 } 206 } 207 208 return listener, nil 209 } 210 211 // listenAndServeStream starts listening, returns an error on failure. 212 func (s *serverTransport) listenAndServeStream(ctx context.Context, opts *ListenServeOptions) error { 213 if opts.FramerBuilder == nil { 214 return errors.New("tcp transport FramerBuilder empty") 215 } 216 ln, err := s.getTCPListener(opts) 217 if err != nil { 218 return fmt.Errorf("get tcp listener err: %w", err) 219 } 220 // We MUST save the raw TCP listener (instead of (*tls.listener) if TLS is enabled) 221 // to guarantee the underlying fd can be successfully retrieved for hot restart. 222 listenersMap.Store(ln, struct{}{}) 223 ln, err = mayLiftToTLSListener(ln, opts) 224 if err != nil { 225 return fmt.Errorf("may lift to tls listener err: %w", err) 226 } 227 go s.serveStream(ctx, ln, opts) 228 return nil 229 } 230 231 func mayLiftToTLSListener(ln net.Listener, opts *ListenServeOptions) (net.Listener, error) { 232 if !(len(opts.TLSCertFile) > 0 && len(opts.TLSKeyFile) > 0) { 233 return ln, nil 234 } 235 // Enable TLS. 236 tlsConf, err := itls.GetServerConfig(opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile) 237 if err != nil { 238 return nil, fmt.Errorf("tls get server config err: %w", err) 239 } 240 return tls.NewListener(ln, tlsConf), nil 241 } 242 243 func (s *serverTransport) serveStream(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error { 244 var once sync.Once 245 closeListener := func() { ln.Close() } 246 defer once.Do(closeListener) 247 // Create a goroutine to watch ctx.Done() channel. 
248 // Once Server.Close(), TCP listener should be closed immediately and won't accept any new connection. 249 go func() { 250 select { 251 case <-ctx.Done(): 252 // ctx.Done will perform the following two actions: 253 // 1. Stop listening. 254 // 2. Cancel all currently established connections. 255 // Whereas opts.StopListening will only stop listening. 256 case <-opts.StopListening: 257 } 258 log.Tracef("recv server close event") 259 once.Do(closeListener) 260 }() 261 return s.serveTCP(ctx, ln, opts) 262 } 263 264 // ---------------------------------packet server-----------------------------------------// 265 266 // listenAndServePacket starts listening, returns an error on failure. 267 func (s *serverTransport) listenAndServePacket(ctx context.Context, opts *ListenServeOptions) error { 268 pool := createUDPRoutinePool(opts.Routines) 269 // Reuse port. To speed up IO, the kernel dispatches IO ReadReady events to threads. 270 if s.opts.ReusePort { 271 reuseport.ListenerBacklogMaxSize = 4096 272 cores := runtime.NumCPU() 273 for i := 0; i < cores; i++ { 274 udpconn, err := s.getUDPListener(opts) 275 if err != nil { 276 return err 277 } 278 listenersMap.Store(udpconn, struct{}{}) 279 280 go s.servePacket(ctx, udpconn, pool, opts) 281 } 282 } else { 283 udpconn, err := s.getUDPListener(opts) 284 if err != nil { 285 return err 286 } 287 listenersMap.Store(udpconn, struct{}{}) 288 289 go s.servePacket(ctx, udpconn, pool, opts) 290 } 291 return nil 292 } 293 294 // getUDPListener gets UDP listener. 295 func (s *serverTransport) getUDPListener(opts *ListenServeOptions) (udpConn net.PacketConn, err error) { 296 v, _ := os.LookupEnv(EnvGraceRestart) 297 ok, _ := strconv.ParseBool(v) 298 if ok { 299 // Find the passed listener. 
300 ln, err := getPassedListener(opts.Network, opts.Address) 301 if err != nil { 302 return nil, err 303 } 304 listener, ok := ln.(net.PacketConn) 305 if !ok { 306 return nil, errors.New("invalid net.PacketConn") 307 } 308 return listener, nil 309 } 310 311 if s.opts.ReusePort { 312 udpConn, err = reuseport.ListenPacket(opts.Network, opts.Address) 313 if err != nil { 314 return nil, fmt.Errorf("udp reuseport error:%v", err) 315 } 316 } else { 317 udpConn, err = net.ListenPacket(opts.Network, opts.Address) 318 if err != nil { 319 return nil, fmt.Errorf("udp listen error:%v", err) 320 } 321 } 322 323 return udpConn, nil 324 } 325 326 func (s *serverTransport) servePacket(ctx context.Context, rwc net.PacketConn, pool *ants.PoolWithFunc, 327 opts *ListenServeOptions) error { 328 switch rwc := rwc.(type) { 329 case *net.UDPConn: 330 return s.serveUDP(ctx, rwc, pool, opts) 331 default: 332 return errors.New("transport not support PacketConn impl") 333 } 334 } 335 336 // ------------------------ tcp/udp connection structures ----------------------------// 337 338 func (s *serverTransport) newConn(ctx context.Context, opts *ListenServeOptions) *conn { 339 idleTimeout := opts.IdleTimeout 340 if s.opts.IdleTimeout > 0 { 341 idleTimeout = s.opts.IdleTimeout 342 } 343 return &conn{ 344 ctx: ctx, 345 handler: opts.Handler, 346 idleTimeout: idleTimeout, 347 } 348 } 349 350 // conn is the struct of connection which is established when server receive a client connecting 351 // request. 
352 type conn struct { 353 ctx context.Context 354 cancelCtx context.CancelFunc 355 idleTimeout time.Duration 356 lastVisited time.Time 357 handler Handler 358 } 359 360 func (c *conn) handle(ctx context.Context, req []byte) ([]byte, error) { 361 return c.handler.Handle(ctx, req) 362 } 363 364 func (c *conn) handleClose(ctx context.Context) error { 365 if closeHandler, ok := c.handler.(CloseHandler); ok { 366 return closeHandler.HandleClose(ctx) 367 } 368 return nil 369 } 370 371 var errNotFound = errors.New("listener not found") 372 373 // GetPassedListener gets the inherited listener from parent process by network and address. 374 func GetPassedListener(network, address string) (interface{}, error) { 375 return getPassedListener(network, address) 376 } 377 378 func getPassedListener(network, address string) (interface{}, error) { 379 once.Do(inheritListeners) 380 381 key := network + ":" + address 382 v, ok := inheritedListenersMap.Load(key) 383 if !ok { 384 return nil, errNotFound 385 } 386 387 listeners := v.([]interface{}) 388 if len(listeners) == 0 { 389 return nil, errNotFound 390 } 391 392 ln := listeners[0] 393 listeners = listeners[1:] 394 if len(listeners) == 0 { 395 inheritedListenersMap.Delete(key) 396 } else { 397 inheritedListenersMap.Store(key, listeners) 398 } 399 400 return ln, nil 401 } 402 403 // ListenFd is the listener fd. 404 type ListenFd struct { 405 Fd uintptr 406 Name string 407 Network string 408 Address string 409 } 410 411 // inheritListeners stores the listener according to start listenfd and number of listenfd passed 412 // by environment variables. 
413 func inheritListeners() { 414 firstListenFd, err := strconv.ParseUint(os.Getenv(EnvGraceFirstFd), 10, 32) 415 if err != nil { 416 log.Errorf("invalid %s, error: %v", EnvGraceFirstFd, err) 417 } 418 419 num, err := strconv.ParseUint(os.Getenv(EnvGraceRestartFdNum), 10, 32) 420 if err != nil { 421 log.Errorf("invalid %s, error: %v", EnvGraceRestartFdNum, err) 422 } 423 424 for fd := firstListenFd; fd < firstListenFd+num; fd++ { 425 file := os.NewFile(uintptr(fd), "") 426 listener, addr, err := fileListener(file) 427 file.Close() 428 if err != nil { 429 log.Errorf("get file listener error: %v", err) 430 continue 431 } 432 433 key := addr.Network() + ":" + addr.String() 434 v, ok := inheritedListenersMap.LoadOrStore(key, []interface{}{listener}) 435 if ok { 436 listeners := v.([]interface{}) 437 listeners = append(listeners, listener) 438 inheritedListenersMap.Store(key, listeners) 439 } 440 } 441 } 442 443 func fileListener(file *os.File) (interface{}, net.Addr, error) { 444 // Check file status. 445 fin, err := file.Stat() 446 if err != nil { 447 return nil, nil, err 448 } 449 450 // Is this a socket fd. 451 if fin.Mode()&os.ModeSocket == 0 { 452 return nil, nil, errFileIsNotSocket 453 } 454 455 // tcp, tcp4 or tcp6. 456 if listener, err := net.FileListener(file); err == nil { 457 return listener, listener.Addr(), nil 458 } 459 460 // udp, udp4 or udp6. 
461 if packetConn, err := net.FilePacketConn(file); err == nil { 462 return packetConn, packetConn.LocalAddr(), nil 463 } 464 465 return nil, nil, errUnSupportedNetworkType 466 } 467 468 func getPacketConnFd(c net.PacketConn) (*ListenFd, error) { 469 sc, ok := c.(syscall.Conn) 470 if !ok { 471 return nil, fmt.Errorf("getPacketConnFd err: %w", errUnSupportedListenerType) 472 } 473 lnFd, err := getRawFd(sc) 474 if err != nil { 475 return nil, fmt.Errorf("getPacketConnFd getRawFd err: %w", err) 476 } 477 return &ListenFd{ 478 Fd: lnFd, 479 Name: "a udp listener fd", 480 Network: c.LocalAddr().Network(), 481 Address: c.LocalAddr().String(), 482 }, nil 483 } 484 485 func getListenerFd(ln net.Listener) (*ListenFd, error) { 486 sc, ok := ln.(syscall.Conn) 487 if !ok { 488 return nil, fmt.Errorf("getListenerFd err: %w", errUnSupportedListenerType) 489 } 490 fd, err := getRawFd(sc) 491 if err != nil { 492 return nil, fmt.Errorf("getListenerFd getRawFd err: %w", err) 493 } 494 return &ListenFd{ 495 Fd: fd, 496 Name: "a tcp listener fd", 497 Network: ln.Addr().Network(), 498 Address: ln.Addr().String(), 499 }, nil 500 } 501 502 // getRawFd acts like: 503 // 504 // func (ln *net.TCPListener) (uintptr, error) { 505 // f, err := ln.File() 506 // if err != nil { 507 // return 0, err 508 // } 509 // fd, err := f.Fd() 510 // if err != nil { 511 // return 0, err 512 // } 513 // } 514 // 515 // But it differs in an important way: 516 // 517 // The method (*os.File).Fd() will set the original file descriptor to blocking mode as a side effect of fcntl(), 518 // which will lead to indefinite hangs of Close/Read/Write, etc. 
//
// The descriptor is captured inside RawConn.Control and returned as-is; it stays
// usable only while the owning listener/conn remains open.
//
// References:
//   - https://github.com/golang/go/issues/29277
//   - https://github.com/golang/go/issues/29277#issuecomment-447526159
//   - https://github.com/golang/go/issues/29277#issuecomment-448117332
//   - https://github.com/golang/go/issues/43894
func getRawFd(sc syscall.Conn) (uintptr, error) {
	rawConn, connErr := sc.SyscallConn()
	if connErr != nil {
		return 0, fmt.Errorf("sc.SyscallConn err: %w", connErr)
	}

	var captured uintptr
	grab := func(fd uintptr) { captured = fd }
	if ctrlErr := rawConn.Control(grab); ctrlErr != nil {
		return 0, fmt.Errorf("c.Control err: %w", ctrlErr)
	}
	return captured, nil
}