github.com/ianic/xnet/aio@v0.0.0-20230924160527-cee7f41ab201/tcp_listener.go (about)

     1  package aio
     2  
     3  import (
     4  	"log/slog"
     5  	"net"
     6  	"syscall"
     7  	"unsafe"
     8  
     9  	_ "unsafe"
    10  
    11  	"golang.org/x/sys/unix"
    12  )
    13  
    14  type TCPListener struct {
    15  	loop        *Loop
    16  	fd          int
    17  	port        int
    18  	accepted    Accepted
    19  	connections map[int]*TCPConn
    20  }
    21  
    22  func (l *TCPListener) accept() {
    23  	var cb completionCallback
    24  	cb = func(res int32, flags uint32, err *ErrErrno) {
    25  		if err == nil {
    26  			fd := int(res)
    27  			// create new tcp connection and bind it with upstream layer
    28  			tc := newTcpConn(l.loop, func() { delete(l.connections, fd) }, fd)
    29  			l.accepted(fd, tc)
    30  			l.connections[fd] = tc
    31  			return
    32  		}
    33  		if err.Temporary() {
    34  			l.loop.prepareMultishotAccept(l.fd, cb)
    35  			return
    36  		}
    37  		if !err.Canceled() {
    38  			slog.Debug("listener accept", "fd", l.fd, "errno", err, "res", res, "flags", flags)
    39  		}
    40  	}
    41  	l.loop.prepareMultishotAccept(l.fd, cb)
    42  }
    43  
    44  func (l *TCPListener) Close() {
    45  	l.close(true)
    46  }
    47  
    48  func (l *TCPListener) close(shutdownConnections bool) {
    49  	l.loop.prepareCancelFd(l.fd, func(res int32, flags uint32, err *ErrErrno) {
    50  		if err != nil {
    51  			slog.Debug("listener cancel", "fd", l.fd, "err", err, "res", res, "flags", flags)
    52  		}
    53  		if shutdownConnections {
    54  			for _, conn := range l.connections {
    55  				conn.shutdown(ErrListenerClose)
    56  			}
    57  		}
    58  		delete(l.loop.listeners, l.fd)
    59  	})
    60  }
    61  
    62  func socket(sa syscall.Sockaddr) (int, error) {
    63  	domain := syscall.AF_INET
    64  	switch sa.(type) {
    65  	case *syscall.SockaddrInet6:
    66  		domain = syscall.AF_INET6
    67  	}
    68  	return syscall.Socket(domain, syscall.SOCK_STREAM, 0)
    69  }
    70  
    71  func listen(sa syscall.Sockaddr, domain int) (int, int, error) {
    72  	port := 0
    73  	fd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
    74  	if err != nil {
    75  		return 0, 0, err
    76  	}
    77  	if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
    78  		return 0, 0, err
    79  	}
    80  	if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
    81  		return 0, 0, err
    82  	}
    83  	if err := syscall.Bind(fd, sa); err != nil {
    84  		return 0, 0, err
    85  	}
    86  	if port == 0 {
    87  		// get system assigned port
    88  		if sn, err := syscall.Getsockname(fd); err == nil {
    89  			switch v := sn.(type) {
    90  			case *syscall.SockaddrInet4:
    91  				port = v.Port
    92  			case *syscall.SockaddrInet6:
    93  				port = v.Port
    94  			}
    95  		}
    96  	}
    97  	if err := syscall.SetNonblock(fd, false); err != nil {
    98  		return 0, 0, err
    99  	}
   100  	if err := syscall.Listen(fd, 128); err != nil {
   101  		return 0, 0, err
   102  	}
   103  	return fd, port, nil
   104  }
   105  
   106  // resolveTCPAddr converts string address to syscall.Scokaddr interface used in
   107  // other syscall calls.
   108  // "www.google.com:80"
   109  // "[::1]:0"
   110  // "127.0.0.1:1234"
   111  func resolveTCPAddr(addr string) (syscall.Sockaddr, int, error) {
   112  	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
   113  	if err != nil {
   114  		return nil, 0, err
   115  	}
   116  	ip := tcpAddr.IP
   117  	port := tcpAddr.Port
   118  	if ip4 := ip.To4(); ip4 != nil {
   119  		return &syscall.SockaddrInet4{Port: port, Addr: [4]byte(ip4)}, syscall.AF_INET, nil
   120  	}
   121  	return &syscall.SockaddrInet6{Port: port, Addr: [16]byte(ip)}, syscall.AF_INET6, nil
   122  }
   123  
// sockaddr has no body in this file: the go:linkname directive below binds it
// to the unexported sockaddr method of syscall.Sockaddr, so calling
// sockaddr(addr) invokes that method on addr.
// NOTE(review): the results are presumably a pointer to the raw socket address
// structure, its length in bytes and an error — confirm against the syscall
// package source.
//
//go:linkname sockaddr syscall.Sockaddr.sockaddr
func sockaddr(addr syscall.Sockaddr) (unsafe.Pointer, uint32, error)