github.com/cristalhq/netx@v0.0.0-20221116164110-442313ef3309/listener.go (about)

     1  package netx
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"os"
     9  	"syscall"
    10  	"time"
    11  )
    12  
    13  // TCPListenerConfig holds the socket options used when creating a TCPListener.
    14  type TCPListenerConfig struct {
    15  	// ReusePort enables SO_REUSEPORT on the listening socket.
    16  	ReusePort bool
    17  
    18  	// DeferAccept enables TCP_DEFER_ACCEPT on the listening socket.
    19  	DeferAccept bool
    20  
    21  	// FastOpen enables TCP_FASTOPEN on the listening socket.
    22  	FastOpen bool
    23  
    24  	// FastOpenQueueLen is the queue length for TCP_FASTOPEN (default 256).
    25  	FastOpenQueueLen int
    26  
    27  	// Backlog is the maximum number of pending TCP connections the listener
    28  	// may queue before passing them to Accept.
    29  	// If it is zero or negative, the system-level backlog value is used.
    30  	Backlog int
    31  }
    32  
    33  // TCPListener is a TCP listener bound to the addr passed to NewTCPListener.
    34  //
    35  // It also gathers various stats for the accepted connections.
    36  type TCPListener struct {
    37  	net.Listener // underlying listener; Accept delegates to its Accept
    38  	cfg   TCPListenerConfig // config the listener was created with
    39  	stats *Stats // counters updated by Accept and handed to every accepted Conn
    40  }
    41  
    42  // NewTCPListener returns new TCP listener for the given addr.
    43  func NewTCPListener(ctx context.Context, network, addr string, cfg TCPListenerConfig) (*TCPListener, error) {
    44  	ln, err := cfg.newListener(network, addr)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	go func() {
    50  		<-ctx.Done()
    51  		ln.Close()
    52  	}()
    53  
    54  	tln := &TCPListener{
    55  		Listener: ln,
    56  		cfg:      cfg,
    57  		stats:    &Stats{},
    58  	}
    59  	return tln, err
    60  }
    61  
    62  // Accept accepts connections from the addr passed to NewTCPListener.
    63  func (ln *TCPListener) Accept() (net.Conn, error) {
    64  	for {
    65  		conn, err := ln.Listener.Accept()
    66  		ln.stats.acceptsInc()
    67  		if err != nil {
    68  			var ne net.Error
    69  			if errors.As(err, &ne) && ne.Timeout() {
    70  				time.Sleep(10 * time.Millisecond)
    71  				continue
    72  			}
    73  			ln.stats.acceptErrorsInc()
    74  			return nil, err
    75  		}
    76  
    77  		tcpconn, ok := conn.(*net.TCPConn)
    78  		if !ok {
    79  			panic("unreachable")
    80  		}
    81  
    82  		ln.stats.activeConnsInc()
    83  		sc := &Conn{
    84  			TCPConn: *tcpconn,
    85  			stats:   ln.stats,
    86  		}
    87  		return sc, nil
    88  	}
    89  }
    90  
    91  // Stats returns the stats of the listener and of its accepted connections.
    92  func (ln *TCPListener) Stats() *Stats {
    93  	return ln.stats // the listener's own Stats pointer, not a snapshot copy
    94  }
    95  
    96  func (cfg *TCPListenerConfig) newListener(network, addr string) (net.Listener, error) {
    97  	fd, err := cfg.newSocket(network, addr)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	name := fmt.Sprintf("netx.%d.%s.%s", os.Getpid(), network, addr)
   103  	file := os.NewFile(uintptr(fd), name)
   104  
   105  	ln, err := net.FileListener(file)
   106  	if err != nil {
   107  		file.Close()
   108  		return nil, err
   109  	}
   110  
   111  	if err := file.Close(); err != nil {
   112  		ln.Close()
   113  		return nil, err
   114  	}
   115  	return ln, nil
   116  }
   117  
   118  func (cfg *TCPListenerConfig) newSocket(network, addr string) (fd int, err error) {
   119  	sa, domain, err := getTCPSockaddr(network, addr)
   120  	if err != nil {
   121  		return 0, err
   122  	}
   123  
   124  	fd, err = newSocketCloexec(domain, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
   125  	if err != nil {
   126  		return 0, err
   127  	}
   128  
   129  	if err := cfg.fdSetup(fd, sa, addr); err != nil {
   130  		syscall.Close(fd)
   131  		return 0, err
   132  	}
   133  	return fd, nil
   134  }
   135  
   136  func (cfg *TCPListenerConfig) fdSetup(fd int, sa syscall.Sockaddr, addr string) error {
   137  	if err := newError("setsockopt", syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)); err != nil {
   138  		return fmt.Errorf("cannot enable SO_REUSEADDR: %s", err)
   139  	}
   140  
   141  	// This should disable Nagle's algorithm in all accepted sockets by default.
   142  	// Users may enable it with net.TCPConn.SetNoDelay(false).
   143  	if err := newError("setsockopt", syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1)); err != nil {
   144  		return fmt.Errorf("cannot disable Nagle's algorithm: %s", err)
   145  	}
   146  
   147  	if cfg.ReusePort {
   148  		if err := newError("setsockopt", syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1)); err != nil {
   149  			return fmt.Errorf("cannot enable SO_REUSEPORT: %s", err)
   150  		}
   151  	}
   152  
   153  	if cfg.DeferAccept {
   154  		if err := enableDeferAccept(fd); err != nil {
   155  			return err
   156  		}
   157  	}
   158  
   159  	if cfg.FastOpen {
   160  		if err := enableFastOpen(fd, cfg.FastOpenQueueLen); err != nil {
   161  			return err
   162  		}
   163  	}
   164  
   165  	if err := newError("bind", syscall.Bind(fd, sa)); err != nil {
   166  		return fmt.Errorf("cannot bind to %q: %s", addr, err)
   167  	}
   168  
   169  	backlog := cfg.Backlog
   170  	if backlog <= 0 {
   171  		var err error
   172  		if backlog, err = soMaxConn(); err != nil {
   173  			return fmt.Errorf("cannot determine backlog to pass to listen(2): %s", err)
   174  		}
   175  	}
   176  	if err := newError("listen", syscall.Listen(fd, backlog)); err != nil {
   177  		return fmt.Errorf("cannot listen on %q: %s", addr, err)
   178  	}
   179  
   180  	return nil
   181  }
   182  
   183  func newSocketCloexecDefault(domain, typ, proto int) (int, error) {
   184  	syscall.ForkLock.RLock()
   185  	fd, err := syscall.Socket(domain, typ, proto)
   186  	if err == nil {
   187  		syscall.CloseOnExec(fd)
   188  	}
   189  	syscall.ForkLock.RUnlock()
   190  
   191  	if err != nil {
   192  		return -1, fmt.Errorf("cannot create listening socket: %s", err)
   193  	}
   194  
   195  	// TODO(oleg): move to fdSetup ?
   196  	if err := newError("setnonblock", syscall.SetNonblock(fd, true)); err != nil {
   197  		syscall.Close(fd)
   198  		return -1, fmt.Errorf("cannot make non-blocked listening socket: %s", err)
   199  	}
   200  	return fd, nil
   201  }
   202  
   203  func getTCPSockaddr(network, addr string) (sa syscall.Sockaddr, domain int, err error) {
   204  	tcp, err := net.ResolveTCPAddr(network, addr)
   205  	if err != nil {
   206  		return nil, -1, err
   207  	}
   208  
   209  	switch network {
   210  	case "tcp":
   211  		return &syscall.SockaddrInet4{Port: tcp.Port}, syscall.AF_INET, nil
   212  	case "tcp4":
   213  		sa := &syscall.SockaddrInet4{Port: tcp.Port}
   214  		if tcp.IP != nil {
   215  			if len(tcp.IP) == 16 {
   216  				copy(sa.Addr[:], tcp.IP[12:16]) // copy last 4 bytes of slice to array
   217  			} else {
   218  				copy(sa.Addr[:], tcp.IP) // copy all bytes of slice to array
   219  			}
   220  		}
   221  		return sa, syscall.AF_INET, nil
   222  	case "tcp6":
   223  		sa := &syscall.SockaddrInet6{Port: tcp.Port}
   224  
   225  		if tcp.IP != nil {
   226  			copy(sa.Addr[:], tcp.IP) // copy all bytes of slice to array
   227  		}
   228  
   229  		if tcp.Zone != "" {
   230  			iface, err := net.InterfaceByName(tcp.Zone)
   231  			if err != nil {
   232  				return nil, -1, err
   233  			}
   234  
   235  			sa.ZoneId = uint32(iface.Index)
   236  		}
   237  		return sa, syscall.AF_INET6, nil
   238  	default:
   239  		panic("unreachable")
   240  	}
   241  }