github.com/cristalhq/netx@v0.0.0-20221116164110-442313ef3309/listener.go (about) 1 package netx 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "os" 9 "syscall" 10 "time" 11 ) 12 13 // TCPListenerConfig is a config TCPListener. 14 type TCPListenerConfig struct { 15 // ReusePort enables SO_REUSEPORT. 16 ReusePort bool 17 18 // DeferAccept enables TCP_DEFER_ACCEPT. 19 DeferAccept bool 20 21 // FastOpen enables TCP_FASTOPEN. 22 FastOpen bool 23 24 // Queue length for TCP_FASTOPEN (default 256). 25 FastOpenQueueLen int 26 27 // Backlog is the maximum number of pending TCP connections the listener 28 // may queue before passing them to Accept. 29 // Default is system-level backlog value is used. 30 Backlog int 31 } 32 33 // TCPListener listens for the addr passed to NewTCPListener. 34 // 35 // It also gathers various stats for the accepted connections. 36 type TCPListener struct { 37 net.Listener 38 cfg TCPListenerConfig 39 stats *Stats 40 } 41 42 // NewTCPListener returns new TCP listener for the given addr. 43 func NewTCPListener(ctx context.Context, network, addr string, cfg TCPListenerConfig) (*TCPListener, error) { 44 ln, err := cfg.newListener(network, addr) 45 if err != nil { 46 return nil, err 47 } 48 49 go func() { 50 <-ctx.Done() 51 ln.Close() 52 }() 53 54 tln := &TCPListener{ 55 Listener: ln, 56 cfg: cfg, 57 stats: &Stats{}, 58 } 59 return tln, err 60 } 61 62 // Accept accepts connections from the addr passed to NewTCPListener. 63 func (ln *TCPListener) Accept() (net.Conn, error) { 64 for { 65 conn, err := ln.Listener.Accept() 66 ln.stats.acceptsInc() 67 if err != nil { 68 var ne net.Error 69 if errors.As(err, &ne) && ne.Timeout() { 70 time.Sleep(10 * time.Millisecond) 71 continue 72 } 73 ln.stats.acceptErrorsInc() 74 return nil, err 75 } 76 77 tcpconn, ok := conn.(*net.TCPConn) 78 if !ok { 79 panic("unreachable") 80 } 81 82 ln.stats.activeConnsInc() 83 sc := &Conn{ 84 TCPConn: *tcpconn, 85 stats: ln.stats, 86 } 87 return sc, nil 88 } 89 } 90 91 // Stats of the listener and accepted connections. 92 func (ln *TCPListener) Stats() *Stats { 93 return ln.stats 94 } 95 96 func (cfg *TCPListenerConfig) newListener(network, addr string) (net.Listener, error) { 97 fd, err := cfg.newSocket(network, addr) 98 if err != nil { 99 return nil, err 100 } 101 102 name := fmt.Sprintf("netx.%d.%s.%s", os.Getpid(), network, addr) 103 file := os.NewFile(uintptr(fd), name) 104 105 ln, err := net.FileListener(file) 106 if err != nil { 107 file.Close() 108 return nil, err 109 } 110 111 if err := file.Close(); err != nil { 112 ln.Close() 113 return nil, err 114 } 115 return ln, nil 116 } 117 118 func (cfg *TCPListenerConfig) newSocket(network, addr string) (fd int, err error) { 119 sa, domain, err := getTCPSockaddr(network, addr) 120 if err != nil { 121 return 0, err 122 } 123 124 fd, err = newSocketCloexec(domain, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) 125 if err != nil { 126 return 0, err 127 } 128 129 if err := cfg.fdSetup(fd, sa, addr); err != nil { 130 syscall.Close(fd) 131 return 0, err 132 } 133 return fd, nil 134 } 135 136 func (cfg *TCPListenerConfig) fdSetup(fd int, sa syscall.Sockaddr, addr string) error { 137 if err := newError("setsockopt", syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)); err != nil { 138 return fmt.Errorf("cannot enable SO_REUSEADDR: %s", err) 139 } 140 141 // This should disable Nagle's algorithm in all accepted sockets by default. 142 // Users may enable it with net.TCPConn.SetNoDelay(false). 143 if err := newError("setsockopt", syscall.SetsockoptInt(fd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, 1)); err != nil { 144 return fmt.Errorf("cannot disable Nagle's algorithm: %s", err) 145 } 146 147 if cfg.ReusePort { 148 if err := newError("setsockopt", syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1)); err != nil { 149 return fmt.Errorf("cannot enable SO_REUSEPORT: %s", err) 150 } 151 } 152 153 if cfg.DeferAccept { 154 if err := enableDeferAccept(fd); err != nil { 155 return err 156 } 157 } 158 159 if cfg.FastOpen { 160 if err := enableFastOpen(fd, cfg.FastOpenQueueLen); err != nil { 161 return err 162 } 163 } 164 165 if err := newError("bind", syscall.Bind(fd, sa)); err != nil { 166 return fmt.Errorf("cannot bind to %q: %s", addr, err) 167 } 168 169 backlog := cfg.Backlog 170 if backlog <= 0 { 171 var err error 172 if backlog, err = soMaxConn(); err != nil { 173 return fmt.Errorf("cannot determine backlog to pass to listen(2): %s", err) 174 } 175 } 176 if err := newError("listen", syscall.Listen(fd, backlog)); err != nil { 177 return fmt.Errorf("cannot listen on %q: %s", addr, err) 178 } 179 180 return nil 181 } 182 183 func newSocketCloexecDefault(domain, typ, proto int) (int, error) { 184 syscall.ForkLock.RLock() 185 fd, err := syscall.Socket(domain, typ, proto) 186 if err == nil { 187 syscall.CloseOnExec(fd) 188 } 189 syscall.ForkLock.RUnlock() 190 191 if err != nil { 192 return -1, fmt.Errorf("cannot create listening socket: %s", err) 193 } 194 195 // TODO(oleg): move to fdSetup ? 196 if err := newError("setnonblock", syscall.SetNonblock(fd, true)); err != nil { 197 syscall.Close(fd) 198 return -1, fmt.Errorf("cannot make non-blocked listening socket: %s", err) 199 } 200 return fd, nil 201 } 202 203 func getTCPSockaddr(network, addr string) (sa syscall.Sockaddr, domain int, err error) { 204 tcp, err := net.ResolveTCPAddr(network, addr) 205 if err != nil { 206 return nil, -1, err 207 } 208 209 switch network { 210 case "tcp": 211 return &syscall.SockaddrInet4{Port: tcp.Port}, syscall.AF_INET, nil 212 case "tcp4": 213 sa := &syscall.SockaddrInet4{Port: tcp.Port} 214 if tcp.IP != nil { 215 if len(tcp.IP) == 16 { 216 copy(sa.Addr[:], tcp.IP[12:16]) // copy last 4 bytes of slice to array 217 } else { 218 copy(sa.Addr[:], tcp.IP) // copy all bytes of slice to array 219 } 220 } 221 return sa, syscall.AF_INET, nil 222 case "tcp6": 223 sa := &syscall.SockaddrInet6{Port: tcp.Port} 224 225 if tcp.IP != nil { 226 copy(sa.Addr[:], tcp.IP) // copy all bytes of slice to array 227 } 228 229 if tcp.Zone != "" { 230 iface, err := net.InterfaceByName(tcp.Zone) 231 if err != nil { 232 return nil, -1, err 233 } 234 235 sa.ZoneId = uint32(iface.Index) 236 } 237 return sa, syscall.AF_INET6, nil 238 default: 239 panic("unreachable") 240 } 241 }