trpc.group/trpc-go/trpc-go@v1.0.3/transport/server_transport.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"errors"
    20  	"fmt"
    21  	"net"
    22  	"os"
    23  	"runtime"
    24  	"strconv"
    25  	"strings"
    26  	"sync"
    27  	"syscall"
    28  	"time"
    29  
    30  	"github.com/panjf2000/ants/v2"
    31  	"trpc.group/trpc-go/trpc-go/internal/reuseport"
    32  
    33  	itls "trpc.group/trpc-go/trpc-go/internal/tls"
    34  	"trpc.group/trpc-go/trpc-go/log"
    35  )
    36  
// transportName is the name under which init registers the default stream
// server transport.
const transportName = "go-net"

// init registers DefaultServerStreamTransport (not declared in this file) under
// transportName. NOTE(review): presumably so it can be selected by name through
// RegisterServerTransport's registry — confirm there.
func init() {
	RegisterServerTransport(transportName, DefaultServerStreamTransport)
}
    42  
// Environment variables used to pass listening sockets from a parent process
// during a graceful restart.
const (
	// EnvGraceRestart is the flag of graceful restart. Its value is parsed with
	// strconv.ParseBool; when true, listeners are looked up among the fds
	// inherited from the parent instead of being created anew.
	EnvGraceRestart = "TRPC_IS_GRACEFUL"

	// EnvGraceFirstFd is the fd of graceful first listener: the decimal number
	// of the lowest inherited listener fd.
	EnvGraceFirstFd = "TRPC_GRACEFUL_1ST_LISTENFD"

	// EnvGraceRestartFdNum is the number of fd for graceful restart: how many
	// consecutive fds, starting at EnvGraceFirstFd, are treated as inherited
	// listeners.
	EnvGraceRestartFdNum = "TRPC_GRACEFUL_LISTENFD_NUM"

	// EnvGraceRestartPPID is the PPID of graceful restart. It is not read in
	// this file.
	EnvGraceRestartPPID = "TRPC_GRACEFUL_PPID"
)
    56  
// Sentinel errors returned by the listener/fd helpers in this file.
var (
	// errUnSupportedListenerType is wrapped when a listener or packet conn does
	// not implement syscall.Conn, so its raw fd cannot be obtained.
	errUnSupportedListenerType = errors.New("not supported listener type")
	// errUnSupportedNetworkType is returned when an inherited socket fd is
	// neither a stream listener nor a packet conn.
	errUnSupportedNetworkType  = errors.New("not supported network type")
	// errFileIsNotSocket is returned when an inherited fd is not a socket.
	errFileIsNotSocket         = errors.New("file is not a socket")
)
    62  
// DefaultServerTransport is the default implementation of ServerTransport,
// created with reuse-port enabled. Note that init registers
// DefaultServerStreamTransport, not this value, under transportName.
var DefaultServerTransport = NewServerTransport(WithReusePort(true))
    65  
    66  // NewServerTransport creates a new ServerTransport.
    67  func NewServerTransport(opt ...ServerTransportOption) ServerTransport {
    68  	r := newServerTransport(opt...)
    69  	return &r
    70  }
    71  
    72  // newServerTransport creates a new serverTransport.
    73  func newServerTransport(opt ...ServerTransportOption) serverTransport {
    74  	// this is the default option.
    75  	opts := defaultServerTransportOptions()
    76  	for _, o := range opt {
    77  		o(opts)
    78  	}
    79  	addrToConn := make(map[string]*tcpconn)
    80  	return serverTransport{addrToConn: addrToConn, m: &sync.RWMutex{}, opts: opts}
    81  }
    82  
// serverTransport is the implementation details of server transport, may be tcp or udp.
// ListenAndServe dispatches tcp/tcp4/tcp6/unix to the stream path and
// udp/udp4/udp6 to the packet path.
type serverTransport struct {
	// addrToConn maps an address string to a tcp connection. It is only
	// initialized in this file — NOTE(review): its readers/writers live elsewhere.
	addrToConn map[string]*tcpconn
	// m is presumably the lock guarding addrToConn — confirm at its use sites.
	m          *sync.RWMutex
	// opts holds the transport-level options (e.g. ReusePort, IdleTimeout).
	opts       *ServerTransportOptions
}
    89  
    90  // ListenAndServe starts Listening, returns an error on failure.
    91  func (s *serverTransport) ListenAndServe(ctx context.Context, opts ...ListenServeOption) error {
    92  	lsopts := &ListenServeOptions{}
    93  	for _, opt := range opts {
    94  		opt(lsopts)
    95  	}
    96  
    97  	if lsopts.Listener != nil {
    98  		return s.listenAndServeStream(ctx, lsopts)
    99  	}
   100  	// Support simultaneous listening TCP and UDP.
   101  	networks := strings.Split(lsopts.Network, ",")
   102  	for _, network := range networks {
   103  		lsopts.Network = network
   104  		switch lsopts.Network {
   105  		case "tcp", "tcp4", "tcp6", "unix":
   106  			if err := s.listenAndServeStream(ctx, lsopts); err != nil {
   107  				return err
   108  			}
   109  		case "udp", "udp4", "udp6":
   110  			if err := s.listenAndServePacket(ctx, lsopts); err != nil {
   111  				return err
   112  			}
   113  		default:
   114  			return fmt.Errorf("server transport: not support network type %s", lsopts.Network)
   115  		}
   116  	}
   117  	return nil
   118  }
   119  
   120  // ---------------------------------stream server-----------------------------------------//
   121  
var (
	// listenersMap records the listeners in use in the current process.
	// Keys are net.Listener or net.PacketConn values (values are struct{}{});
	// GetListenersFds enumerates it to collect fds for a hot restart.
	listenersMap = &sync.Map{}
	// inheritedListenersMap record the listeners inherited from the parent process.
	// Keys are "network:address" strings (e.g. "tcp:127.0.0.1:8000"); values are
	// []interface{} holding the net.Listener / net.PacketConn values not yet
	// claimed by getPassedListener. A key may have multiple listener fds.
	inheritedListenersMap = &sync.Map{}
	// once controls fds passed from parent process to construct listeners: it
	// guards the single call to inheritListeners. NOTE(review): the generic
	// name is easily shadowed by local variables called once.
	once sync.Once
)
   131  
   132  // GetListenersFds gets listener fds.
   133  func GetListenersFds() []*ListenFd {
   134  	listenersFds := []*ListenFd{}
   135  	listenersMap.Range(func(key, _ interface{}) bool {
   136  		var (
   137  			fd  *ListenFd
   138  			err error
   139  		)
   140  
   141  		switch k := key.(type) {
   142  		case net.Listener:
   143  			fd, err = getListenerFd(k)
   144  		case net.PacketConn:
   145  			fd, err = getPacketConnFd(k)
   146  		default:
   147  			log.Errorf("listener type passing not supported, type: %T", key)
   148  			err = fmt.Errorf("not supported listener type: %T", key)
   149  		}
   150  		if err != nil {
   151  			log.Errorf("cannot get the listener fd, err: %v", err)
   152  			return true
   153  		}
   154  		listenersFds = append(listenersFds, fd)
   155  		return true
   156  	})
   157  	return listenersFds
   158  }
   159  
   160  // SaveListener saves the listener.
   161  func SaveListener(listener interface{}) error {
   162  	switch listener.(type) {
   163  	case net.Listener, net.PacketConn:
   164  		listenersMap.Store(listener, struct{}{})
   165  	default:
   166  		return fmt.Errorf("not supported listener type: %T", listener)
   167  	}
   168  	return nil
   169  }
   170  
   171  // getTCPListener gets the TCP/Unix listener.
   172  func (s *serverTransport) getTCPListener(opts *ListenServeOptions) (listener net.Listener, err error) {
   173  	listener = opts.Listener
   174  
   175  	if listener != nil {
   176  		return listener, nil
   177  	}
   178  
   179  	v, _ := os.LookupEnv(EnvGraceRestart)
   180  	ok, _ := strconv.ParseBool(v)
   181  	if ok {
   182  		// find the passed listener
   183  		pln, err := getPassedListener(opts.Network, opts.Address)
   184  		if err != nil {
   185  			return nil, err
   186  		}
   187  
   188  		listener, ok := pln.(net.Listener)
   189  		if !ok {
   190  			return nil, errors.New("invalid net.Listener")
   191  		}
   192  		return listener, nil
   193  	}
   194  
   195  	// Reuse port. To speed up IO, the kernel dispatches IO ReadReady events to threads.
   196  	if s.opts.ReusePort && opts.Network != "unix" {
   197  		listener, err = reuseport.Listen(opts.Network, opts.Address)
   198  		if err != nil {
   199  			return nil, fmt.Errorf("%s reuseport error:%v", opts.Network, err)
   200  		}
   201  	} else {
   202  		listener, err = net.Listen(opts.Network, opts.Address)
   203  		if err != nil {
   204  			return nil, err
   205  		}
   206  	}
   207  
   208  	return listener, nil
   209  }
   210  
   211  // listenAndServeStream starts listening, returns an error on failure.
   212  func (s *serverTransport) listenAndServeStream(ctx context.Context, opts *ListenServeOptions) error {
   213  	if opts.FramerBuilder == nil {
   214  		return errors.New("tcp transport FramerBuilder empty")
   215  	}
   216  	ln, err := s.getTCPListener(opts)
   217  	if err != nil {
   218  		return fmt.Errorf("get tcp listener err: %w", err)
   219  	}
   220  	// We MUST save the raw TCP listener (instead of (*tls.listener) if TLS is enabled)
   221  	// to guarantee the underlying fd can be successfully retrieved for hot restart.
   222  	listenersMap.Store(ln, struct{}{})
   223  	ln, err = mayLiftToTLSListener(ln, opts)
   224  	if err != nil {
   225  		return fmt.Errorf("may lift to tls listener err: %w", err)
   226  	}
   227  	go s.serveStream(ctx, ln, opts)
   228  	return nil
   229  }
   230  
   231  func mayLiftToTLSListener(ln net.Listener, opts *ListenServeOptions) (net.Listener, error) {
   232  	if !(len(opts.TLSCertFile) > 0 && len(opts.TLSKeyFile) > 0) {
   233  		return ln, nil
   234  	}
   235  	// Enable TLS.
   236  	tlsConf, err := itls.GetServerConfig(opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile)
   237  	if err != nil {
   238  		return nil, fmt.Errorf("tls get server config err: %w", err)
   239  	}
   240  	return tls.NewListener(ln, tlsConf), nil
   241  }
   242  
   243  func (s *serverTransport) serveStream(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error {
   244  	var once sync.Once
   245  	closeListener := func() { ln.Close() }
   246  	defer once.Do(closeListener)
   247  	// Create a goroutine to watch ctx.Done() channel.
   248  	// Once Server.Close(), TCP listener should be closed immediately and won't accept any new connection.
   249  	go func() {
   250  		select {
   251  		case <-ctx.Done():
   252  		// ctx.Done will perform the following two actions:
   253  		// 1. Stop listening.
   254  		// 2. Cancel all currently established connections.
   255  		// Whereas opts.StopListening will only stop listening.
   256  		case <-opts.StopListening:
   257  		}
   258  		log.Tracef("recv server close event")
   259  		once.Do(closeListener)
   260  	}()
   261  	return s.serveTCP(ctx, ln, opts)
   262  }
   263  
   264  // ---------------------------------packet server-----------------------------------------//
   265  
   266  // listenAndServePacket starts listening, returns an error on failure.
   267  func (s *serverTransport) listenAndServePacket(ctx context.Context, opts *ListenServeOptions) error {
   268  	pool := createUDPRoutinePool(opts.Routines)
   269  	// Reuse port. To speed up IO, the kernel dispatches IO ReadReady events to threads.
   270  	if s.opts.ReusePort {
   271  		reuseport.ListenerBacklogMaxSize = 4096
   272  		cores := runtime.NumCPU()
   273  		for i := 0; i < cores; i++ {
   274  			udpconn, err := s.getUDPListener(opts)
   275  			if err != nil {
   276  				return err
   277  			}
   278  			listenersMap.Store(udpconn, struct{}{})
   279  
   280  			go s.servePacket(ctx, udpconn, pool, opts)
   281  		}
   282  	} else {
   283  		udpconn, err := s.getUDPListener(opts)
   284  		if err != nil {
   285  			return err
   286  		}
   287  		listenersMap.Store(udpconn, struct{}{})
   288  
   289  		go s.servePacket(ctx, udpconn, pool, opts)
   290  	}
   291  	return nil
   292  }
   293  
   294  // getUDPListener gets UDP listener.
   295  func (s *serverTransport) getUDPListener(opts *ListenServeOptions) (udpConn net.PacketConn, err error) {
   296  	v, _ := os.LookupEnv(EnvGraceRestart)
   297  	ok, _ := strconv.ParseBool(v)
   298  	if ok {
   299  		// Find the passed listener.
   300  		ln, err := getPassedListener(opts.Network, opts.Address)
   301  		if err != nil {
   302  			return nil, err
   303  		}
   304  		listener, ok := ln.(net.PacketConn)
   305  		if !ok {
   306  			return nil, errors.New("invalid net.PacketConn")
   307  		}
   308  		return listener, nil
   309  	}
   310  
   311  	if s.opts.ReusePort {
   312  		udpConn, err = reuseport.ListenPacket(opts.Network, opts.Address)
   313  		if err != nil {
   314  			return nil, fmt.Errorf("udp reuseport error:%v", err)
   315  		}
   316  	} else {
   317  		udpConn, err = net.ListenPacket(opts.Network, opts.Address)
   318  		if err != nil {
   319  			return nil, fmt.Errorf("udp listen error:%v", err)
   320  		}
   321  	}
   322  
   323  	return udpConn, nil
   324  }
   325  
   326  func (s *serverTransport) servePacket(ctx context.Context, rwc net.PacketConn, pool *ants.PoolWithFunc,
   327  	opts *ListenServeOptions) error {
   328  	switch rwc := rwc.(type) {
   329  	case *net.UDPConn:
   330  		return s.serveUDP(ctx, rwc, pool, opts)
   331  	default:
   332  		return errors.New("transport not support PacketConn impl")
   333  	}
   334  }
   335  
   336  // ------------------------ tcp/udp connection structures ----------------------------//
   337  
   338  func (s *serverTransport) newConn(ctx context.Context, opts *ListenServeOptions) *conn {
   339  	idleTimeout := opts.IdleTimeout
   340  	if s.opts.IdleTimeout > 0 {
   341  		idleTimeout = s.opts.IdleTimeout
   342  	}
   343  	return &conn{
   344  		ctx:         ctx,
   345  		handler:     opts.Handler,
   346  		idleTimeout: idleTimeout,
   347  	}
   348  }
   349  
// conn is the struct of connection which is established when server receive a client connecting
// request.
type conn struct {
	// ctx is the context the connection was created with (see newConn).
	ctx         context.Context
	// cancelCtx is not assigned in this file — presumably set by the code that
	// derives a cancellable ctx for the connection; confirm there.
	cancelCtx   context.CancelFunc
	// idleTimeout is the transport-level IdleTimeout when positive, otherwise
	// the listen-level one (see newConn).
	idleTimeout time.Duration
	// lastVisited is not assigned in this file; presumably the time of the last
	// activity used for idle checks — confirm at its use sites.
	lastVisited time.Time
	// handler processes requests; if it implements CloseHandler, it also
	// receives HandleClose calls (see handleClose).
	handler     Handler
}
   359  
   360  func (c *conn) handle(ctx context.Context, req []byte) ([]byte, error) {
   361  	return c.handler.Handle(ctx, req)
   362  }
   363  
   364  func (c *conn) handleClose(ctx context.Context) error {
   365  	if closeHandler, ok := c.handler.(CloseHandler); ok {
   366  		return closeHandler.HandleClose(ctx)
   367  	}
   368  	return nil
   369  }
   370  
   371  var errNotFound = errors.New("listener not found")
   372  
   373  // GetPassedListener gets the inherited listener from parent process by network and address.
   374  func GetPassedListener(network, address string) (interface{}, error) {
   375  	return getPassedListener(network, address)
   376  }
   377  
   378  func getPassedListener(network, address string) (interface{}, error) {
   379  	once.Do(inheritListeners)
   380  
   381  	key := network + ":" + address
   382  	v, ok := inheritedListenersMap.Load(key)
   383  	if !ok {
   384  		return nil, errNotFound
   385  	}
   386  
   387  	listeners := v.([]interface{})
   388  	if len(listeners) == 0 {
   389  		return nil, errNotFound
   390  	}
   391  
   392  	ln := listeners[0]
   393  	listeners = listeners[1:]
   394  	if len(listeners) == 0 {
   395  		inheritedListenersMap.Delete(key)
   396  	} else {
   397  		inheritedListenersMap.Store(key, listeners)
   398  	}
   399  
   400  	return ln, nil
   401  }
   402  
// ListenFd is the listener fd.
// It describes one listening socket as reported by GetListenersFds.
type ListenFd struct {
	// Fd is the raw file descriptor, read via syscall.RawConn.Control.
	Fd      uintptr
	// Name is a human-readable label ("a tcp listener fd" or "a udp listener fd").
	Name    string
	// Network is the network of the socket's local address, e.g. "tcp" or "udp".
	Network string
	// Address is the socket's local address in string form.
	Address string
}
   410  
   411  // inheritListeners stores the listener according to start listenfd and number of listenfd passed
   412  // by environment variables.
   413  func inheritListeners() {
   414  	firstListenFd, err := strconv.ParseUint(os.Getenv(EnvGraceFirstFd), 10, 32)
   415  	if err != nil {
   416  		log.Errorf("invalid %s, error: %v", EnvGraceFirstFd, err)
   417  	}
   418  
   419  	num, err := strconv.ParseUint(os.Getenv(EnvGraceRestartFdNum), 10, 32)
   420  	if err != nil {
   421  		log.Errorf("invalid %s, error: %v", EnvGraceRestartFdNum, err)
   422  	}
   423  
   424  	for fd := firstListenFd; fd < firstListenFd+num; fd++ {
   425  		file := os.NewFile(uintptr(fd), "")
   426  		listener, addr, err := fileListener(file)
   427  		file.Close()
   428  		if err != nil {
   429  			log.Errorf("get file listener error: %v", err)
   430  			continue
   431  		}
   432  
   433  		key := addr.Network() + ":" + addr.String()
   434  		v, ok := inheritedListenersMap.LoadOrStore(key, []interface{}{listener})
   435  		if ok {
   436  			listeners := v.([]interface{})
   437  			listeners = append(listeners, listener)
   438  			inheritedListenersMap.Store(key, listeners)
   439  		}
   440  	}
   441  }
   442  
   443  func fileListener(file *os.File) (interface{}, net.Addr, error) {
   444  	// Check file status.
   445  	fin, err := file.Stat()
   446  	if err != nil {
   447  		return nil, nil, err
   448  	}
   449  
   450  	// Is this a socket fd.
   451  	if fin.Mode()&os.ModeSocket == 0 {
   452  		return nil, nil, errFileIsNotSocket
   453  	}
   454  
   455  	// tcp, tcp4 or tcp6.
   456  	if listener, err := net.FileListener(file); err == nil {
   457  		return listener, listener.Addr(), nil
   458  	}
   459  
   460  	// udp, udp4 or udp6.
   461  	if packetConn, err := net.FilePacketConn(file); err == nil {
   462  		return packetConn, packetConn.LocalAddr(), nil
   463  	}
   464  
   465  	return nil, nil, errUnSupportedNetworkType
   466  }
   467  
   468  func getPacketConnFd(c net.PacketConn) (*ListenFd, error) {
   469  	sc, ok := c.(syscall.Conn)
   470  	if !ok {
   471  		return nil, fmt.Errorf("getPacketConnFd err: %w", errUnSupportedListenerType)
   472  	}
   473  	lnFd, err := getRawFd(sc)
   474  	if err != nil {
   475  		return nil, fmt.Errorf("getPacketConnFd getRawFd err: %w", err)
   476  	}
   477  	return &ListenFd{
   478  		Fd:      lnFd,
   479  		Name:    "a udp listener fd",
   480  		Network: c.LocalAddr().Network(),
   481  		Address: c.LocalAddr().String(),
   482  	}, nil
   483  }
   484  
   485  func getListenerFd(ln net.Listener) (*ListenFd, error) {
   486  	sc, ok := ln.(syscall.Conn)
   487  	if !ok {
   488  		return nil, fmt.Errorf("getListenerFd err: %w", errUnSupportedListenerType)
   489  	}
   490  	fd, err := getRawFd(sc)
   491  	if err != nil {
   492  		return nil, fmt.Errorf("getListenerFd getRawFd err: %w", err)
   493  	}
   494  	return &ListenFd{
   495  		Fd:      fd,
   496  		Name:    "a tcp listener fd",
   497  		Network: ln.Addr().Network(),
   498  		Address: ln.Addr().String(),
   499  	}, nil
   500  }
   501  
   502  // getRawFd acts like:
   503  //
   504  //	func (ln *net.TCPListener) (uintptr, error) {
   505  //		f, err := ln.File()
   506  //		if err != nil {
   507  //			return 0, err
   508  //		}
   509  //		fd, err := f.Fd()
   510  //		if err != nil {
   511  //			return 0, err
   512  //		}
   513  //	}
   514  //
   515  // But it differs in an important way:
   516  //
   517  //	The method (*os.File).Fd() will set the original file descriptor to blocking mode as a side effect of fcntl(),
   518  //	which will lead to indefinite hangs of Close/Read/Write, etc.
   519  //
   520  // References:
   521  //   - https://github.com/golang/go/issues/29277
   522  //   - https://github.com/golang/go/issues/29277#issuecomment-447526159
   523  //   - https://github.com/golang/go/issues/29277#issuecomment-448117332
   524  //   - https://github.com/golang/go/issues/43894
   525  func getRawFd(sc syscall.Conn) (uintptr, error) {
   526  	c, err := sc.SyscallConn()
   527  	if err != nil {
   528  		return 0, fmt.Errorf("sc.SyscallConn err: %w", err)
   529  	}
   530  	var lnFd uintptr
   531  	if err := c.Control(func(fd uintptr) {
   532  		lnFd = fd
   533  	}); err != nil {
   534  		return 0, fmt.Errorf("c.Control err: %w", err)
   535  	}
   536  	return lnFd, nil
   537  }