trpc.group/trpc-go/trpc-go@v1.0.2/transport/server_transport.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"net"
    23  	"os"
    24  	"runtime"
    25  	"strconv"
    26  	"strings"
    27  	"sync"
    28  	"syscall"
    29  	"time"
    30  
    31  	"github.com/panjf2000/ants/v2"
    32  	"trpc.group/trpc-go/trpc-go/internal/reuseport"
    33  
    34  	itls "trpc.group/trpc-go/trpc-go/internal/tls"
    35  	"trpc.group/trpc-go/trpc-go/log"
    36  )
    37  
// transportName is the name passed to RegisterServerTransport by this file's
// init function.
const transportName = "go-net"
    39  
// init registers DefaultServerStreamTransport under transportName through
// RegisterServerTransport when the package is loaded.
func init() {
	RegisterServerTransport(transportName, DefaultServerStreamTransport)
}
    43  
// Names of the environment variables consulted for graceful restart.
const (
	// EnvGraceRestart is the flag of graceful restart. Its value is parsed with
	// strconv.ParseBool; when it is true, listeners are looked up among the ones
	// inherited from the parent process instead of being created.
	EnvGraceRestart = "TRPC_IS_GRACEFUL"

	// EnvGraceFirstFd is the fd of graceful first listener. It is parsed as an
	// unsigned decimal integer and is the first fd scanned by inheritListeners.
	EnvGraceFirstFd = "TRPC_GRACEFUL_1ST_LISTENFD"

	// EnvGraceRestartFdNum is the number of fd for graceful restart, i.e. how
	// many consecutive fds starting at EnvGraceFirstFd are scanned.
	EnvGraceRestartFdNum = "TRPC_GRACEFUL_LISTENFD_NUM"

	// EnvGraceRestartPPID is the PPID of graceful restart.
	// NOTE(review): not read anywhere in this file; presumably consumed by the
	// restart logic elsewhere — confirm.
	EnvGraceRestartPPID = "TRPC_GRACEFUL_PPID"
)
    57  
// Sentinel errors returned by the listener/fd helpers in this file.
var (
	// errUnSupportedListenerType is wrapped by getListenerFd and getPacketConnFd
	// when the listener does not implement syscall.Conn.
	errUnSupportedListenerType = errors.New("not supported listener type")
	// errUnSupportedNetworkType is returned by fileListener when a socket file
	// can be turned into neither a net.Listener nor a net.PacketConn.
	errUnSupportedNetworkType = errors.New("not supported network type")
	// errFileIsNotSocket is returned by fileListener when the file's mode does
	// not have os.ModeSocket set.
	errFileIsNotSocket = errors.New("file is not a socket")
)
    63  
// DefaultServerTransport is the default ServerTransport. It is built by
// NewServerTransport with the WithReusePort(true) option.
var DefaultServerTransport = NewServerTransport(WithReusePort(true))
    66  
    67  // NewServerTransport creates a new ServerTransport.
    68  func NewServerTransport(opt ...ServerTransportOption) ServerTransport {
    69  	r := newServerTransport(opt...)
    70  	return &r
    71  }
    72  
    73  // newServerTransport creates a new serverTransport.
    74  func newServerTransport(opt ...ServerTransportOption) serverTransport {
    75  	// this is the default option.
    76  	opts := defaultServerTransportOptions()
    77  	for _, o := range opt {
    78  		o(opts)
    79  	}
    80  	addrToConn := make(map[string]*tcpconn)
    81  	return serverTransport{addrToConn: addrToConn, m: &sync.RWMutex{}, opts: opts}
    82  }
    83  
// serverTransport is the implementation details of server transport, may be tcp or udp.
type serverTransport struct {
	// addrToConn maps a string key to a *tcpconn.
	// NOTE(review): presumably keyed by the peer address — confirm against the
	// code that fills it (not in this file).
	addrToConn map[string]*tcpconn
	// m is allocated together with addrToConn in newServerTransport.
	// NOTE(review): presumably guards addrToConn — confirm.
	m *sync.RWMutex
	// opts holds the transport-level options (ReusePort and IdleTimeout are
	// read in this file).
	opts *ServerTransportOptions
}
    90  
    91  // ListenAndServe starts Listening, returns an error on failure.
    92  func (s *serverTransport) ListenAndServe(ctx context.Context, opts ...ListenServeOption) error {
    93  	lsopts := &ListenServeOptions{}
    94  	for _, opt := range opts {
    95  		opt(lsopts)
    96  	}
    97  
    98  	if lsopts.Listener != nil {
    99  		return s.listenAndServeStream(ctx, lsopts)
   100  	}
   101  	// Support simultaneous listening TCP and UDP.
   102  	networks := strings.Split(lsopts.Network, ",")
   103  	for _, network := range networks {
   104  		lsopts.Network = network
   105  		switch lsopts.Network {
   106  		case "tcp", "tcp4", "tcp6", "unix":
   107  			if err := s.listenAndServeStream(ctx, lsopts); err != nil {
   108  				return err
   109  			}
   110  		case "udp", "udp4", "udp6":
   111  			if err := s.listenAndServePacket(ctx, lsopts); err != nil {
   112  				return err
   113  			}
   114  		default:
   115  			return fmt.Errorf("server transport: not support network type %s", lsopts.Network)
   116  		}
   117  	}
   118  	return nil
   119  }
   120  
   121  // ---------------------------------stream server-----------------------------------------//
   122  
// Process-wide listener registries used for hot (graceful) restart.
var (
	// listenersMap records the listeners in use in the current process.
	// The listener itself (a net.Listener or net.PacketConn) is the key and the
	// value is an empty struct.
	listenersMap = &sync.Map{}
	// inheritedListenersMap records the listeners inherited from the parent process.
	// The key is "network:address" and the value is a []interface{}, because a
	// key may have multiple listener fds.
	inheritedListenersMap = &sync.Map{}
	// once makes sure the fds passed from the parent process are turned into
	// listeners (inheritListeners) only one time.
	once sync.Once
)
   132  
   133  // GetListenersFds gets listener fds.
   134  func GetListenersFds() []*ListenFd {
   135  	listenersFds := []*ListenFd{}
   136  	listenersMap.Range(func(key, _ interface{}) bool {
   137  		var (
   138  			fd  *ListenFd
   139  			err error
   140  		)
   141  
   142  		switch k := key.(type) {
   143  		case net.Listener:
   144  			fd, err = getListenerFd(k)
   145  		case net.PacketConn:
   146  			fd, err = getPacketConnFd(k)
   147  		default:
   148  			log.Errorf("listener type passing not supported, type: %T", key)
   149  			err = fmt.Errorf("not supported listener type: %T", key)
   150  		}
   151  		if err != nil {
   152  			log.Errorf("cannot get the listener fd, err: %v", err)
   153  			return true
   154  		}
   155  		listenersFds = append(listenersFds, fd)
   156  		return true
   157  	})
   158  	return listenersFds
   159  }
   160  
   161  // SaveListener saves the listener.
   162  func SaveListener(listener interface{}) error {
   163  	switch listener.(type) {
   164  	case net.Listener, net.PacketConn:
   165  		listenersMap.Store(listener, struct{}{})
   166  	default:
   167  		return fmt.Errorf("not supported listener type: %T", listener)
   168  	}
   169  	return nil
   170  }
   171  
   172  // getTCPListener gets the TCP/Unix listener.
   173  func (s *serverTransport) getTCPListener(opts *ListenServeOptions) (listener net.Listener, err error) {
   174  	listener = opts.Listener
   175  
   176  	if listener != nil {
   177  		return listener, nil
   178  	}
   179  
   180  	v, _ := os.LookupEnv(EnvGraceRestart)
   181  	ok, _ := strconv.ParseBool(v)
   182  	if ok {
   183  		// find the passed listener
   184  		pln, err := getPassedListener(opts.Network, opts.Address)
   185  		if err != nil {
   186  			return nil, err
   187  		}
   188  
   189  		listener, ok := pln.(net.Listener)
   190  		if !ok {
   191  			return nil, errors.New("invalid net.Listener")
   192  		}
   193  		return listener, nil
   194  	}
   195  
   196  	// Reuse port. To speed up IO, the kernel dispatches IO ReadReady events to threads.
   197  	if s.opts.ReusePort && opts.Network != "unix" {
   198  		listener, err = reuseport.Listen(opts.Network, opts.Address)
   199  		if err != nil {
   200  			return nil, fmt.Errorf("%s reuseport error:%v", opts.Network, err)
   201  		}
   202  	} else {
   203  		listener, err = net.Listen(opts.Network, opts.Address)
   204  		if err != nil {
   205  			return nil, err
   206  		}
   207  	}
   208  
   209  	return listener, nil
   210  }
   211  
   212  // listenAndServeStream starts listening, returns an error on failure.
   213  func (s *serverTransport) listenAndServeStream(ctx context.Context, opts *ListenServeOptions) error {
   214  	if opts.FramerBuilder == nil {
   215  		return errors.New("tcp transport FramerBuilder empty")
   216  	}
   217  	ln, err := s.getTCPListener(opts)
   218  	if err != nil {
   219  		return fmt.Errorf("get tcp listener err: %w", err)
   220  	}
   221  	// We MUST save the raw TCP listener (instead of (*tls.listener) if TLS is enabled)
   222  	// to guarantee the underlying fd can be successfully retrieved for hot restart.
   223  	listenersMap.Store(ln, struct{}{})
   224  	ln, err = mayLiftToTLSListener(ln, opts)
   225  	if err != nil {
   226  		return fmt.Errorf("may lift to tls listener err: %w", err)
   227  	}
   228  	go s.serveStream(ctx, ln, opts)
   229  	return nil
   230  }
   231  
   232  func mayLiftToTLSListener(ln net.Listener, opts *ListenServeOptions) (net.Listener, error) {
   233  	if !(len(opts.TLSCertFile) > 0 && len(opts.TLSKeyFile) > 0) {
   234  		return ln, nil
   235  	}
   236  	// Enable TLS.
   237  	tlsConf, err := itls.GetServerConfig(opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile)
   238  	if err != nil {
   239  		return nil, fmt.Errorf("tls get server config err: %w", err)
   240  	}
   241  	return tls.NewListener(ln, tlsConf), nil
   242  }
   243  
// serveStream serves the stream listener ln by delegating to serveTCP and
// returns whatever serveTCP returns.
func (s *serverTransport) serveStream(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error {
	return s.serveTCP(ctx, ln, opts)
}
   247  
   248  // ---------------------------------packet server-----------------------------------------//
   249  
   250  // listenAndServePacket starts listening, returns an error on failure.
   251  func (s *serverTransport) listenAndServePacket(ctx context.Context, opts *ListenServeOptions) error {
   252  	pool := createUDPRoutinePool(opts.Routines)
   253  	// Reuse port. To speed up IO, the kernel dispatches IO ReadReady events to threads.
   254  	if s.opts.ReusePort {
   255  		reuseport.ListenerBacklogMaxSize = 4096
   256  		cores := runtime.NumCPU()
   257  		for i := 0; i < cores; i++ {
   258  			udpconn, err := s.getUDPListener(opts)
   259  			if err != nil {
   260  				return err
   261  			}
   262  			listenersMap.Store(udpconn, struct{}{})
   263  
   264  			go s.servePacket(ctx, udpconn, pool, opts)
   265  		}
   266  	} else {
   267  		udpconn, err := s.getUDPListener(opts)
   268  		if err != nil {
   269  			return err
   270  		}
   271  		listenersMap.Store(udpconn, struct{}{})
   272  
   273  		go s.servePacket(ctx, udpconn, pool, opts)
   274  	}
   275  	return nil
   276  }
   277  
   278  // getUDPListener gets UDP listener.
   279  func (s *serverTransport) getUDPListener(opts *ListenServeOptions) (udpConn net.PacketConn, err error) {
   280  	v, _ := os.LookupEnv(EnvGraceRestart)
   281  	ok, _ := strconv.ParseBool(v)
   282  	if ok {
   283  		// Find the passed listener.
   284  		ln, err := getPassedListener(opts.Network, opts.Address)
   285  		if err != nil {
   286  			return nil, err
   287  		}
   288  		listener, ok := ln.(net.PacketConn)
   289  		if !ok {
   290  			return nil, errors.New("invalid net.PacketConn")
   291  		}
   292  		return listener, nil
   293  	}
   294  
   295  	if s.opts.ReusePort {
   296  		udpConn, err = reuseport.ListenPacket(opts.Network, opts.Address)
   297  		if err != nil {
   298  			return nil, fmt.Errorf("udp reuseport error:%v", err)
   299  		}
   300  	} else {
   301  		udpConn, err = net.ListenPacket(opts.Network, opts.Address)
   302  		if err != nil {
   303  			return nil, fmt.Errorf("udp listen error:%v", err)
   304  		}
   305  	}
   306  
   307  	return udpConn, nil
   308  }
   309  
   310  func (s *serverTransport) servePacket(ctx context.Context, rwc net.PacketConn, pool *ants.PoolWithFunc,
   311  	opts *ListenServeOptions) error {
   312  	switch rwc := rwc.(type) {
   313  	case *net.UDPConn:
   314  		return s.serveUDP(ctx, rwc, pool, opts)
   315  	default:
   316  		return errors.New("transport not support PacketConn impl")
   317  	}
   318  }
   319  
   320  // ------------------------ tcp/udp connection structures ----------------------------//
   321  
   322  func (s *serverTransport) newConn(ctx context.Context, opts *ListenServeOptions) *conn {
   323  	idleTimeout := opts.IdleTimeout
   324  	if s.opts.IdleTimeout > 0 {
   325  		idleTimeout = s.opts.IdleTimeout
   326  	}
   327  	return &conn{
   328  		ctx:         ctx,
   329  		handler:     opts.Handler,
   330  		idleTimeout: idleTimeout,
   331  	}
   332  }
   333  
// conn is the struct of connection which is established when server receive a client connecting
// request.
type conn struct {
	// ctx is the context given to newConn.
	ctx context.Context
	// cancelCtx is not set by newConn.
	// NOTE(review): presumably cancels ctx when the connection ends — confirm
	// against the tcp/udp code that assigns it.
	cancelCtx context.CancelFunc
	// idleTimeout is the idle timeout chosen in newConn (the transport-level
	// value when positive, otherwise the listen option).
	idleTimeout time.Duration
	// lastVisited is not set by newConn.
	// NOTE(review): presumably the time of the last activity, used together
	// with idleTimeout — confirm.
	lastVisited time.Time
	// handler processes requests (handle) and, if it also implements
	// CloseHandler, close notifications (handleClose).
	handler Handler
}
   343  
// handle passes the request bytes to the connection's handler and returns the
// handler's response and error unchanged.
func (c *conn) handle(ctx context.Context, req []byte) ([]byte, error) {
	return c.handler.Handle(ctx, req)
}
   347  
   348  func (c *conn) handleClose(ctx context.Context) error {
   349  	if closeHandler, ok := c.handler.(CloseHandler); ok {
   350  		return closeHandler.HandleClose(ctx)
   351  	}
   352  	return nil
   353  }
   354  
// errNotFound is returned by getPassedListener when no inherited listener is
// left for the requested network and address.
var errNotFound = errors.New("listener not found")
   356  
// GetPassedListener gets the inherited listener from parent process by network and address.
// It is the exported entry point of getPassedListener: the returned value is
// either a net.Listener or a net.PacketConn, and each successful call removes
// the returned listener from the set of inherited listeners.
func GetPassedListener(network, address string) (interface{}, error) {
	return getPassedListener(network, address)
}
   361  
   362  func getPassedListener(network, address string) (interface{}, error) {
   363  	once.Do(inheritListeners)
   364  
   365  	key := network + ":" + address
   366  	v, ok := inheritedListenersMap.Load(key)
   367  	if !ok {
   368  		return nil, errNotFound
   369  	}
   370  
   371  	listeners := v.([]interface{})
   372  	if len(listeners) == 0 {
   373  		return nil, errNotFound
   374  	}
   375  
   376  	ln := listeners[0]
   377  	listeners = listeners[1:]
   378  	if len(listeners) == 0 {
   379  		inheritedListenersMap.Delete(key)
   380  	} else {
   381  		inheritedListenersMap.Store(key, listeners)
   382  	}
   383  
   384  	return ln, nil
   385  }
   386  
// ListenFd is the listener fd. It describes one listener of the current
// process together with its raw file descriptor.
type ListenFd struct {
	// OriginalListenCloser is the listener (net.Listener or net.PacketConn)
	// the fd was taken from.
	OriginalListenCloser io.Closer
	// Fd is the raw file descriptor obtained through getRawFd.
	Fd uintptr
	// Name is a fixed human-readable label ("a tcp listener fd" or
	// "a udp listener fd").
	Name string
	// Network is the network name of the listener's local address.
	Network string
	// Address is the string form of the listener's local address.
	Address string
}
   395  
   396  // inheritListeners stores the listener according to start listenfd and number of listenfd passed
   397  // by environment variables.
   398  func inheritListeners() {
   399  	firstListenFd, err := strconv.ParseUint(os.Getenv(EnvGraceFirstFd), 10, 32)
   400  	if err != nil {
   401  		log.Errorf("invalid %s, error: %v", EnvGraceFirstFd, err)
   402  	}
   403  
   404  	num, err := strconv.ParseUint(os.Getenv(EnvGraceRestartFdNum), 10, 32)
   405  	if err != nil {
   406  		log.Errorf("invalid %s, error: %v", EnvGraceRestartFdNum, err)
   407  	}
   408  
   409  	for fd := firstListenFd; fd < firstListenFd+num; fd++ {
   410  		file := os.NewFile(uintptr(fd), "")
   411  		listener, addr, err := fileListener(file)
   412  		file.Close()
   413  		if err != nil {
   414  			log.Errorf("get file listener error: %v", err)
   415  			continue
   416  		}
   417  
   418  		key := addr.Network() + ":" + addr.String()
   419  		v, ok := inheritedListenersMap.LoadOrStore(key, []interface{}{listener})
   420  		if ok {
   421  			listeners := v.([]interface{})
   422  			listeners = append(listeners, listener)
   423  			inheritedListenersMap.Store(key, listeners)
   424  		}
   425  	}
   426  }
   427  
   428  func fileListener(file *os.File) (interface{}, net.Addr, error) {
   429  	// Check file status.
   430  	fin, err := file.Stat()
   431  	if err != nil {
   432  		return nil, nil, err
   433  	}
   434  
   435  	// Is this a socket fd.
   436  	if fin.Mode()&os.ModeSocket == 0 {
   437  		return nil, nil, errFileIsNotSocket
   438  	}
   439  
   440  	// tcp, tcp4 or tcp6.
   441  	if listener, err := net.FileListener(file); err == nil {
   442  		return listener, listener.Addr(), nil
   443  	}
   444  
   445  	// udp, udp4 or udp6.
   446  	if packetConn, err := net.FilePacketConn(file); err == nil {
   447  		return packetConn, packetConn.LocalAddr(), nil
   448  	}
   449  
   450  	return nil, nil, errUnSupportedNetworkType
   451  }
   452  
   453  func getPacketConnFd(c net.PacketConn) (*ListenFd, error) {
   454  	sc, ok := c.(syscall.Conn)
   455  	if !ok {
   456  		return nil, fmt.Errorf("getPacketConnFd err: %w", errUnSupportedListenerType)
   457  	}
   458  	lnFd, err := getRawFd(sc)
   459  	if err != nil {
   460  		return nil, fmt.Errorf("getPacketConnFd getRawFd err: %w", err)
   461  	}
   462  	return &ListenFd{
   463  		OriginalListenCloser: c,
   464  		Fd:                   lnFd,
   465  		Name:                 "a udp listener fd",
   466  		Network:              c.LocalAddr().Network(),
   467  		Address:              c.LocalAddr().String(),
   468  	}, nil
   469  }
   470  
   471  func getListenerFd(ln net.Listener) (*ListenFd, error) {
   472  	sc, ok := ln.(syscall.Conn)
   473  	if !ok {
   474  		return nil, fmt.Errorf("getListenerFd err: %w", errUnSupportedListenerType)
   475  	}
   476  	fd, err := getRawFd(sc)
   477  	if err != nil {
   478  		return nil, fmt.Errorf("getListenerFd getRawFd err: %w", err)
   479  	}
   480  	return &ListenFd{
   481  		OriginalListenCloser: ln,
   482  		Fd:                   fd,
   483  		Name:                 "a tcp listener fd",
   484  		Network:              ln.Addr().Network(),
   485  		Address:              ln.Addr().String(),
   486  	}, nil
   487  }
   488  
   489  // getRawFd acts like:
   490  //
   491  //	func (ln *net.TCPListener) (uintptr, error) {
   492  //		f, err := ln.File()
   493  //		if err != nil {
   494  //			return 0, err
   495  //		}
   496  //		fd, err := f.Fd()
   497  //		if err != nil {
   498  //			return 0, err
   499  //		}
   500  //	}
   501  //
   502  // But it differs in an important way:
   503  //
   504  //	The method (*os.File).Fd() will set the original file descriptor to blocking mode as a side effect of fcntl(),
   505  //	which will lead to indefinite hangs of Close/Read/Write, etc.
   506  //
   507  // References:
   508  //   - https://github.com/golang/go/issues/29277
   509  //   - https://github.com/golang/go/issues/29277#issuecomment-447526159
   510  //   - https://github.com/golang/go/issues/29277#issuecomment-448117332
   511  //   - https://github.com/golang/go/issues/43894
   512  func getRawFd(sc syscall.Conn) (uintptr, error) {
   513  	c, err := sc.SyscallConn()
   514  	if err != nil {
   515  		return 0, fmt.Errorf("sc.SyscallConn err: %w", err)
   516  	}
   517  	var lnFd uintptr
   518  	if err := c.Control(func(fd uintptr) {
   519  		lnFd = fd
   520  	}); err != nil {
   521  		return 0, fmt.Errorf("c.Control err: %w", err)
   522  	}
   523  	return lnFd, nil
   524  }