trpc.group/trpc-go/trpc-go@v1.0.2/transport/server_transport.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package transport 15 16 import ( 17 "context" 18 "crypto/tls" 19 "errors" 20 "fmt" 21 "io" 22 "net" 23 "os" 24 "runtime" 25 "strconv" 26 "strings" 27 "sync" 28 "syscall" 29 "time" 30 31 "github.com/panjf2000/ants/v2" 32 "trpc.group/trpc-go/trpc-go/internal/reuseport" 33 34 itls "trpc.group/trpc-go/trpc-go/internal/tls" 35 "trpc.group/trpc-go/trpc-go/log" 36 ) 37 38 const transportName = "go-net" 39 40 func init() { 41 RegisterServerTransport(transportName, DefaultServerStreamTransport) 42 } 43 44 const ( 45 // EnvGraceRestart is the flag of graceful restart. 46 EnvGraceRestart = "TRPC_IS_GRACEFUL" 47 48 // EnvGraceFirstFd is the fd of graceful first listener. 49 EnvGraceFirstFd = "TRPC_GRACEFUL_1ST_LISTENFD" 50 51 // EnvGraceRestartFdNum is the number of fd for graceful restart. 52 EnvGraceRestartFdNum = "TRPC_GRACEFUL_LISTENFD_NUM" 53 54 // EnvGraceRestartPPID is the PPID of graceful restart. 55 EnvGraceRestartPPID = "TRPC_GRACEFUL_PPID" 56 ) 57 58 var ( 59 errUnSupportedListenerType = errors.New("not supported listener type") 60 errUnSupportedNetworkType = errors.New("not supported network type") 61 errFileIsNotSocket = errors.New("file is not a socket") 62 ) 63 64 // DefaultServerTransport is the default implementation of ServerStreamTransport. 65 var DefaultServerTransport = NewServerTransport(WithReusePort(true)) 66 67 // NewServerTransport creates a new ServerTransport. 
68 func NewServerTransport(opt ...ServerTransportOption) ServerTransport { 69 r := newServerTransport(opt...) 70 return &r 71 } 72 73 // newServerTransport creates a new serverTransport. 74 func newServerTransport(opt ...ServerTransportOption) serverTransport { 75 // this is the default option. 76 opts := defaultServerTransportOptions() 77 for _, o := range opt { 78 o(opts) 79 } 80 addrToConn := make(map[string]*tcpconn) 81 return serverTransport{addrToConn: addrToConn, m: &sync.RWMutex{}, opts: opts} 82 } 83 84 // serverTransport is the implementation details of server transport, may be tcp or udp. 85 type serverTransport struct { 86 addrToConn map[string]*tcpconn 87 m *sync.RWMutex 88 opts *ServerTransportOptions 89 } 90 91 // ListenAndServe starts Listening, returns an error on failure. 92 func (s *serverTransport) ListenAndServe(ctx context.Context, opts ...ListenServeOption) error { 93 lsopts := &ListenServeOptions{} 94 for _, opt := range opts { 95 opt(lsopts) 96 } 97 98 if lsopts.Listener != nil { 99 return s.listenAndServeStream(ctx, lsopts) 100 } 101 // Support simultaneous listening TCP and UDP. 102 networks := strings.Split(lsopts.Network, ",") 103 for _, network := range networks { 104 lsopts.Network = network 105 switch lsopts.Network { 106 case "tcp", "tcp4", "tcp6", "unix": 107 if err := s.listenAndServeStream(ctx, lsopts); err != nil { 108 return err 109 } 110 case "udp", "udp4", "udp6": 111 if err := s.listenAndServePacket(ctx, lsopts); err != nil { 112 return err 113 } 114 default: 115 return fmt.Errorf("server transport: not support network type %s", lsopts.Network) 116 } 117 } 118 return nil 119 } 120 121 // ---------------------------------stream server-----------------------------------------// 122 123 var ( 124 // listenersMap records the listeners in use in the current process. 125 listenersMap = &sync.Map{} 126 // inheritedListenersMap record the listeners inherited from the parent process. 
127 // A key(host:port) may have multiple listener fds. 128 inheritedListenersMap = &sync.Map{} 129 // once controls fds passed from parent process to construct listeners. 130 once sync.Once 131 ) 132 133 // GetListenersFds gets listener fds. 134 func GetListenersFds() []*ListenFd { 135 listenersFds := []*ListenFd{} 136 listenersMap.Range(func(key, _ interface{}) bool { 137 var ( 138 fd *ListenFd 139 err error 140 ) 141 142 switch k := key.(type) { 143 case net.Listener: 144 fd, err = getListenerFd(k) 145 case net.PacketConn: 146 fd, err = getPacketConnFd(k) 147 default: 148 log.Errorf("listener type passing not supported, type: %T", key) 149 err = fmt.Errorf("not supported listener type: %T", key) 150 } 151 if err != nil { 152 log.Errorf("cannot get the listener fd, err: %v", err) 153 return true 154 } 155 listenersFds = append(listenersFds, fd) 156 return true 157 }) 158 return listenersFds 159 } 160 161 // SaveListener saves the listener. 162 func SaveListener(listener interface{}) error { 163 switch listener.(type) { 164 case net.Listener, net.PacketConn: 165 listenersMap.Store(listener, struct{}{}) 166 default: 167 return fmt.Errorf("not supported listener type: %T", listener) 168 } 169 return nil 170 } 171 172 // getTCPListener gets the TCP/Unix listener. 173 func (s *serverTransport) getTCPListener(opts *ListenServeOptions) (listener net.Listener, err error) { 174 listener = opts.Listener 175 176 if listener != nil { 177 return listener, nil 178 } 179 180 v, _ := os.LookupEnv(EnvGraceRestart) 181 ok, _ := strconv.ParseBool(v) 182 if ok { 183 // find the passed listener 184 pln, err := getPassedListener(opts.Network, opts.Address) 185 if err != nil { 186 return nil, err 187 } 188 189 listener, ok := pln.(net.Listener) 190 if !ok { 191 return nil, errors.New("invalid net.Listener") 192 } 193 return listener, nil 194 } 195 196 // Reuse port. To speed up IO, the kernel dispatches IO ReadReady events to threads. 
197 if s.opts.ReusePort && opts.Network != "unix" { 198 listener, err = reuseport.Listen(opts.Network, opts.Address) 199 if err != nil { 200 return nil, fmt.Errorf("%s reuseport error:%v", opts.Network, err) 201 } 202 } else { 203 listener, err = net.Listen(opts.Network, opts.Address) 204 if err != nil { 205 return nil, err 206 } 207 } 208 209 return listener, nil 210 } 211 212 // listenAndServeStream starts listening, returns an error on failure. 213 func (s *serverTransport) listenAndServeStream(ctx context.Context, opts *ListenServeOptions) error { 214 if opts.FramerBuilder == nil { 215 return errors.New("tcp transport FramerBuilder empty") 216 } 217 ln, err := s.getTCPListener(opts) 218 if err != nil { 219 return fmt.Errorf("get tcp listener err: %w", err) 220 } 221 // We MUST save the raw TCP listener (instead of (*tls.listener) if TLS is enabled) 222 // to guarantee the underlying fd can be successfully retrieved for hot restart. 223 listenersMap.Store(ln, struct{}{}) 224 ln, err = mayLiftToTLSListener(ln, opts) 225 if err != nil { 226 return fmt.Errorf("may lift to tls listener err: %w", err) 227 } 228 go s.serveStream(ctx, ln, opts) 229 return nil 230 } 231 232 func mayLiftToTLSListener(ln net.Listener, opts *ListenServeOptions) (net.Listener, error) { 233 if !(len(opts.TLSCertFile) > 0 && len(opts.TLSKeyFile) > 0) { 234 return ln, nil 235 } 236 // Enable TLS. 237 tlsConf, err := itls.GetServerConfig(opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile) 238 if err != nil { 239 return nil, fmt.Errorf("tls get server config err: %w", err) 240 } 241 return tls.NewListener(ln, tlsConf), nil 242 } 243 244 func (s *serverTransport) serveStream(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error { 245 return s.serveTCP(ctx, ln, opts) 246 } 247 248 // ---------------------------------packet server-----------------------------------------// 249 250 // listenAndServePacket starts listening, returns an error on failure. 
251 func (s *serverTransport) listenAndServePacket(ctx context.Context, opts *ListenServeOptions) error { 252 pool := createUDPRoutinePool(opts.Routines) 253 // Reuse port. To speed up IO, the kernel dispatches IO ReadReady events to threads. 254 if s.opts.ReusePort { 255 reuseport.ListenerBacklogMaxSize = 4096 256 cores := runtime.NumCPU() 257 for i := 0; i < cores; i++ { 258 udpconn, err := s.getUDPListener(opts) 259 if err != nil { 260 return err 261 } 262 listenersMap.Store(udpconn, struct{}{}) 263 264 go s.servePacket(ctx, udpconn, pool, opts) 265 } 266 } else { 267 udpconn, err := s.getUDPListener(opts) 268 if err != nil { 269 return err 270 } 271 listenersMap.Store(udpconn, struct{}{}) 272 273 go s.servePacket(ctx, udpconn, pool, opts) 274 } 275 return nil 276 } 277 278 // getUDPListener gets UDP listener. 279 func (s *serverTransport) getUDPListener(opts *ListenServeOptions) (udpConn net.PacketConn, err error) { 280 v, _ := os.LookupEnv(EnvGraceRestart) 281 ok, _ := strconv.ParseBool(v) 282 if ok { 283 // Find the passed listener. 
284 ln, err := getPassedListener(opts.Network, opts.Address) 285 if err != nil { 286 return nil, err 287 } 288 listener, ok := ln.(net.PacketConn) 289 if !ok { 290 return nil, errors.New("invalid net.PacketConn") 291 } 292 return listener, nil 293 } 294 295 if s.opts.ReusePort { 296 udpConn, err = reuseport.ListenPacket(opts.Network, opts.Address) 297 if err != nil { 298 return nil, fmt.Errorf("udp reuseport error:%v", err) 299 } 300 } else { 301 udpConn, err = net.ListenPacket(opts.Network, opts.Address) 302 if err != nil { 303 return nil, fmt.Errorf("udp listen error:%v", err) 304 } 305 } 306 307 return udpConn, nil 308 } 309 310 func (s *serverTransport) servePacket(ctx context.Context, rwc net.PacketConn, pool *ants.PoolWithFunc, 311 opts *ListenServeOptions) error { 312 switch rwc := rwc.(type) { 313 case *net.UDPConn: 314 return s.serveUDP(ctx, rwc, pool, opts) 315 default: 316 return errors.New("transport not support PacketConn impl") 317 } 318 } 319 320 // ------------------------ tcp/udp connection structures ----------------------------// 321 322 func (s *serverTransport) newConn(ctx context.Context, opts *ListenServeOptions) *conn { 323 idleTimeout := opts.IdleTimeout 324 if s.opts.IdleTimeout > 0 { 325 idleTimeout = s.opts.IdleTimeout 326 } 327 return &conn{ 328 ctx: ctx, 329 handler: opts.Handler, 330 idleTimeout: idleTimeout, 331 } 332 } 333 334 // conn is the struct of connection which is established when server receive a client connecting 335 // request. 
336 type conn struct { 337 ctx context.Context 338 cancelCtx context.CancelFunc 339 idleTimeout time.Duration 340 lastVisited time.Time 341 handler Handler 342 } 343 344 func (c *conn) handle(ctx context.Context, req []byte) ([]byte, error) { 345 return c.handler.Handle(ctx, req) 346 } 347 348 func (c *conn) handleClose(ctx context.Context) error { 349 if closeHandler, ok := c.handler.(CloseHandler); ok { 350 return closeHandler.HandleClose(ctx) 351 } 352 return nil 353 } 354 355 var errNotFound = errors.New("listener not found") 356 357 // GetPassedListener gets the inherited listener from parent process by network and address. 358 func GetPassedListener(network, address string) (interface{}, error) { 359 return getPassedListener(network, address) 360 } 361 362 func getPassedListener(network, address string) (interface{}, error) { 363 once.Do(inheritListeners) 364 365 key := network + ":" + address 366 v, ok := inheritedListenersMap.Load(key) 367 if !ok { 368 return nil, errNotFound 369 } 370 371 listeners := v.([]interface{}) 372 if len(listeners) == 0 { 373 return nil, errNotFound 374 } 375 376 ln := listeners[0] 377 listeners = listeners[1:] 378 if len(listeners) == 0 { 379 inheritedListenersMap.Delete(key) 380 } else { 381 inheritedListenersMap.Store(key, listeners) 382 } 383 384 return ln, nil 385 } 386 387 // ListenFd is the listener fd. 388 type ListenFd struct { 389 OriginalListenCloser io.Closer 390 Fd uintptr 391 Name string 392 Network string 393 Address string 394 } 395 396 // inheritListeners stores the listener according to start listenfd and number of listenfd passed 397 // by environment variables. 
398 func inheritListeners() { 399 firstListenFd, err := strconv.ParseUint(os.Getenv(EnvGraceFirstFd), 10, 32) 400 if err != nil { 401 log.Errorf("invalid %s, error: %v", EnvGraceFirstFd, err) 402 } 403 404 num, err := strconv.ParseUint(os.Getenv(EnvGraceRestartFdNum), 10, 32) 405 if err != nil { 406 log.Errorf("invalid %s, error: %v", EnvGraceRestartFdNum, err) 407 } 408 409 for fd := firstListenFd; fd < firstListenFd+num; fd++ { 410 file := os.NewFile(uintptr(fd), "") 411 listener, addr, err := fileListener(file) 412 file.Close() 413 if err != nil { 414 log.Errorf("get file listener error: %v", err) 415 continue 416 } 417 418 key := addr.Network() + ":" + addr.String() 419 v, ok := inheritedListenersMap.LoadOrStore(key, []interface{}{listener}) 420 if ok { 421 listeners := v.([]interface{}) 422 listeners = append(listeners, listener) 423 inheritedListenersMap.Store(key, listeners) 424 } 425 } 426 } 427 428 func fileListener(file *os.File) (interface{}, net.Addr, error) { 429 // Check file status. 430 fin, err := file.Stat() 431 if err != nil { 432 return nil, nil, err 433 } 434 435 // Is this a socket fd. 436 if fin.Mode()&os.ModeSocket == 0 { 437 return nil, nil, errFileIsNotSocket 438 } 439 440 // tcp, tcp4 or tcp6. 441 if listener, err := net.FileListener(file); err == nil { 442 return listener, listener.Addr(), nil 443 } 444 445 // udp, udp4 or udp6. 
446 if packetConn, err := net.FilePacketConn(file); err == nil { 447 return packetConn, packetConn.LocalAddr(), nil 448 } 449 450 return nil, nil, errUnSupportedNetworkType 451 } 452 453 func getPacketConnFd(c net.PacketConn) (*ListenFd, error) { 454 sc, ok := c.(syscall.Conn) 455 if !ok { 456 return nil, fmt.Errorf("getPacketConnFd err: %w", errUnSupportedListenerType) 457 } 458 lnFd, err := getRawFd(sc) 459 if err != nil { 460 return nil, fmt.Errorf("getPacketConnFd getRawFd err: %w", err) 461 } 462 return &ListenFd{ 463 OriginalListenCloser: c, 464 Fd: lnFd, 465 Name: "a udp listener fd", 466 Network: c.LocalAddr().Network(), 467 Address: c.LocalAddr().String(), 468 }, nil 469 } 470 471 func getListenerFd(ln net.Listener) (*ListenFd, error) { 472 sc, ok := ln.(syscall.Conn) 473 if !ok { 474 return nil, fmt.Errorf("getListenerFd err: %w", errUnSupportedListenerType) 475 } 476 fd, err := getRawFd(sc) 477 if err != nil { 478 return nil, fmt.Errorf("getListenerFd getRawFd err: %w", err) 479 } 480 return &ListenFd{ 481 OriginalListenCloser: ln, 482 Fd: fd, 483 Name: "a tcp listener fd", 484 Network: ln.Addr().Network(), 485 Address: ln.Addr().String(), 486 }, nil 487 } 488 489 // getRawFd acts like: 490 // 491 // func (ln *net.TCPListener) (uintptr, error) { 492 // f, err := ln.File() 493 // if err != nil { 494 // return 0, err 495 // } 496 // fd, err := f.Fd() 497 // if err != nil { 498 // return 0, err 499 // } 500 // } 501 // 502 // But it differs in an important way: 503 // 504 // The method (*os.File).Fd() will set the original file descriptor to blocking mode as a side effect of fcntl(), 505 // which will lead to indefinite hangs of Close/Read/Write, etc. 
//
// Instead, getRawFd reads the descriptor inside syscall.RawConn.Control, which
// hands the existing fd to a callback without duplicating it or calling Fd().
// NOTE(review): the syscall.RawConn docs only guarantee the fd is valid while
// the callback runs; callers presumably keep sc open while using it — confirm.
//
// References:
//   - https://github.com/golang/go/issues/29277
//   - https://github.com/golang/go/issues/29277#issuecomment-447526159
//   - https://github.com/golang/go/issues/29277#issuecomment-448117332
//   - https://github.com/golang/go/issues/43894
func getRawFd(sc syscall.Conn) (uintptr, error) {
	rawConn, err := sc.SyscallConn()
	if err != nil {
		return 0, fmt.Errorf("sc.SyscallConn err: %w", err)
	}

	var fd uintptr
	if ctrlErr := rawConn.Control(func(raw uintptr) { fd = raw }); ctrlErr != nil {
		return 0, fmt.Errorf("c.Control err: %w", ctrlErr)
	}
	return fd, nil
}