github.com/ianic/xnet/aio@v0.0.0-20230924160527-cee7f41ab201/loop.go

package aio

import (
	"context"
	"log/slog"
	"math"
	"os"
	"runtime"
	"syscall"
	"time"
	"unsafe"

	"github.com/pawelgaczynski/giouring"
)

const (
	batchSize      = 128
	buffersGroupID = 0 // currently using only 1 provided buffer group
)

// completionCallback is called for every cqe that completes a prepared operation.
type completionCallback = func(res int32, flags uint32, err *ErrErrno)

// operation prepares a single submission queue entry.
type operation = func(*giouring.SubmissionQueueEntry)

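// Loop is an io_uring based event loop. It owns the ring, the completion
// callbacks of in-flight operations, the provided receive buffers and the
// operations that are waiting for a free submission queue entry.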
type Loop struct {
	ring      *giouring.Ring
	callbacks callbacks
	buffers   providedBuffers
	pending   []operation // operations waiting for a free sqe

	listeners   map[int]*TCPListener // keyed by listener fd
	connections map[int]*TCPConn     // dialed connections, keyed by fd
}

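// Options configure the ring size and the provided receive buffers.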
type Options struct {
	RingEntries      uint32
	RecvBuffersCount uint32
	RecvBufferLen    uint32
}

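// DefaultOptions provide a 1024 entry ring and 256 provided receive buffers
// of 4k each.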
var DefaultOptions = Options{
	RingEntries:      1024,
	RecvBuffersCount: 256,
	RecvBufferLen:    4 * 1024,
}

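// New creates a Loop: it creates the io_uring ring and initializes the
// callbacks map and the provided receive buffers. Close releases the ring.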
func New(opt Options) (*Loop, error) {
	ring, err := giouring.CreateRing(opt.RingEntries)
	if err != nil {
		return nil, err
	}
	l := &Loop{
		ring:        ring,
		listeners:   make(map[int]*TCPListener),
		connections: make(map[int]*TCPConn),
	}
	l.callbacks.init()
	if err := l.buffers.init(ring, opt.RecvBuffersCount, opt.RecvBufferLen); err != nil {
		return nil, err
	}
	return l, nil
}

// runOnce performs a single loop iteration: it submits all prepared
// operations to the kernel, waits for at least one completion and then
// processes everything that has completed so far.
func (l *Loop) runOnce() error {
	if err := l.submitAndWait(1); err != nil {
		return err
	}
	_ = l.flushCompletions()
	return nil
}

// runUntilDone runs the loop until all submitted operations have finished.
// At that point all listeners and connections must already be closed,
// otherwise the shutdown was not clean.
func (l *Loop) runUntilDone() error {
	for {
		if l.callbacks.count() == 0 {
			if len(l.connections) > 0 || len(l.listeners) > 0 {
				panic("unclean shutdown")
			}
			return nil
		}
		if err := l.runOnce(); err != nil {
			return err
		}
	}
}

// Run runs the loop until ctx is cancelled, then performs a clean shutdown.
// After ctx is done it closes all open listeners and dialed connections; a
// listener first stops listening and then closes its accepted connections.
// Finally the loop waits for all in-flight operations to finish.
func (l *Loop) Run(ctx context.Context) error {
	// run until ctx is done
	if err := l.runCtx(ctx, time.Millisecond*333); err != nil {
		return err
	}
	l.closePendingConnections()
	// run the loop until all in-flight operations finish
	if err := l.runUntilDone(); err != nil {
		return err
	}
	return nil
}

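// Typical use of the loop, as a minimal sketch (handler wiring and error
// handling are application specific):
//
//	loop, err := aio.New(aio.DefaultOptions)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer loop.Close()
//	_, err = loop.Listen("127.0.0.1:8080", func(fd int, conn *aio.TCPConn) {
//		// attach an application level handler to conn here
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
//	defer stop()
//	err = loop.Run(ctx)
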
// closePendingConnections closes all open listeners and dialed connections.
func (l *Loop) closePendingConnections() {
	for _, lsn := range l.listeners {
		lsn.Close()
	}
	for _, conn := range l.connections {
		conn.Close()
	}
}

// runCtx runs the loop until ctx is cancelled. Because the loop blocks in the
// kernel, ctx is only checked once every `timeout`.
func (l *Loop) runCtx(ctx context.Context, timeout time.Duration) error {
	ts := syscall.NsecToTimespec(int64(timeout))
	done := func() bool {
		select {
		case <-ctx.Done():
			return true
		default:
		}
		return false
	}
	for {
		if err := l.submit(); err != nil {
			return err
		}
		if _, err := l.ring.WaitCQEs(1, &ts, nil); err != nil && !TemporaryError(err) {
			return err
		}
		_ = l.flushCompletions()
		if done() {
			break
		}
	}
	return nil
}

// TemporaryError returns true if the error should be treated as temporary.
func TemporaryError(err error) bool {
	if errno, ok := err.(syscall.Errno); ok {
		return (&ErrErrno{Errno: errno}).Temporary()
	}
	if os.IsTimeout(err) {
		return true
	}
	return false
}

// submitAndWait submits all prepared sqes and waits for at least waitNr
// completions. It retries on temporary errors; anything not handled here is
// fatal and the application should terminate. For the errors io_uring_enter
// can return see [io_uring_enter].
//
// [io_uring_enter]: https://manpages.debian.org/unstable/liburing-dev/io_uring_enter.2.en.html#ERRORS
func (l *Loop) submitAndWait(waitNr uint32) error {
	for {
		if len(l.pending) > 0 {
			// first submit what is already prepared to free sqes,
			// then prepare pending operations on the freed entries
			_, err := l.ring.SubmitAndWait(0)
			if err == nil {
				l.preparePending()
			}
		}

		_, err := l.ring.SubmitAndWait(waitNr)
		if err != nil && TemporaryError(err) {
			continue
		}
		return err
	}
}

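// preparePending moves as many pending operations as possible onto free
// submission queue entries.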
func (l *Loop) preparePending() {
	prepared := 0
	for _, op := range l.pending {
		sqe := l.ring.GetSQE()
		if sqe == nil {
			break
		}
		op(sqe)
		prepared++
	}
	if prepared == len(l.pending) {
		l.pending = nil
	} else {
		l.pending = l.pending[prepared:]
	}
}

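// submit submits prepared sqes without waiting for completions.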
func (l *Loop) submit() error {
	return l.submitAndWait(0)
}

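// flushCompletions drains the completion queue in batches of up to batchSize
// cqes and fires the callback registered for each of them. It returns the
// number of processed completions.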
func (l *Loop) flushCompletions() uint32 {
	var cqes [batchSize]*giouring.CompletionQueueEvent
	var noCompleted uint32 = 0
	for {
		peeked := l.ring.PeekBatchCQE(cqes[:])
		for _, cqe := range cqes[:peeked] {
			err := cqeErr(cqe)
			if cqe.UserData == 0 {
				slog.Debug("cqe without userdata", "res", cqe.Res, "flags", cqe.Flags, "err", err)
				continue
			}
			cb := l.callbacks.get(cqe)
			cb(cqe.Res, cqe.Flags, err)
		}
		l.ring.CQAdvance(peeked)
		noCompleted += peeked
		if peeked < uint32(len(cqes)) {
			return noCompleted
		}
	}
}

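// Close releases the ring and the provided buffers. Call it after Run returns.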
func (l *Loop) Close() {
	l.ring.QueueExit()
	l.buffers.deinit()
}

// prepare gets a free sqe and prepares the operation on it. If the submission
// queue is full it submits once and retries; if there is still no free sqe
// the operation is queued on pending.
func (l *Loop) prepare(op operation) {
	sqe := l.ring.GetSQE()
	if sqe == nil { // submit and retry
		l.submit()
		sqe = l.ring.GetSQE()
	}
	if sqe == nil { // still nothing, add to pending
		l.pending = append(l.pending, op)
		return
	}
	op(sqe)
}

func (l *Loop) prepareMultishotAccept(fd int, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		sqe.PrepareMultishotAccept(fd, 0, 0, 0)
		l.callbacks.set(sqe, cb)
	})
}

func (l *Loop) prepareCancelFd(fd int, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		sqe.PrepareCancelFd(fd, 0)
		l.callbacks.set(sqe, cb)
	})
}

func (l *Loop) prepareShutdown(fd int, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		const SHUT_RDWR = 2
		sqe.PrepareShutdown(fd, SHUT_RDWR)
		l.callbacks.set(sqe, cb)
	})
}

func (l *Loop) prepareClose(fd int, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		sqe.PrepareClose(fd)
		l.callbacks.set(sqe, cb)
	})
}

// assumes that buf is already pinned in the caller
func (l *Loop) prepareSend(fd int, buf []byte, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		sqe.PrepareSend(fd, uintptr(unsafe.Pointer(&buf[0])), uint32(len(buf)), 0)
		l.callbacks.set(sqe, cb)
	})
}

// references from std lib:
// https://github.com/golang/go/blob/140266fe7521bf75bf0037f12265190213cc8e7d/src/internal/poll/writev.go#L16
// https://github.com/golang/go/blob/140266fe7521bf75bf0037f12265190213cc8e7d/src/internal/poll/fd_writev_unix.go#L20
// assumes that iovecs are already pinned in the caller
func (l *Loop) prepareWritev(fd int, iovecs []syscall.Iovec, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		sqe.PrepareWritev(fd, uintptr(unsafe.Pointer(&iovecs[0])), uint32(len(iovecs)), 0)
		l.callbacks.set(sqe, cb)
	})
}

// multishot recv using the provided buffer group; each completion selects a
// buffer and reports its id in the cqe flags (see providedBuffers.get)
func (l *Loop) prepareRecv(fd int, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		sqe.PrepareRecvMultishot(fd, 0, 0, 0)
		sqe.Flags = giouring.SqeBufferSelect
		sqe.BufIG = buffersGroupID
		l.callbacks.set(sqe, cb)
	})
}

func (l *Loop) prepareConnect(fd int, addr uintptr, addrLen uint64, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		sqe.PrepareConnect(fd, addr, addrLen)
		l.callbacks.set(sqe, cb)
	})
}

func (l *Loop) prepareStreamSocket(domain int, cb completionCallback) {
	l.prepare(func(sqe *giouring.SubmissionQueueEntry) {
		sqe.PrepareSocket(domain, syscall.SOCK_STREAM, 0, 0)
		l.callbacks.set(sqe, cb)
	})
}

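// cqeErr converts a failed cqe into an *ErrErrno. The kernel reports errors
// as negated errno values in the cqe res field; results in (-4096, 0) are
// treated as errnos.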
func cqeErr(c *giouring.CompletionQueueEvent) *ErrErrno {
	if c.Res > -4096 && c.Res < 0 {
		errno := syscall.Errno(-c.Res)
		return &ErrErrno{Errno: errno}
	}
	return nil
}

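// ErrErrno wraps a syscall.Errno delivered through a cqe and classifies it.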
type ErrErrno struct {
	Errno syscall.Errno
}

func (e *ErrErrno) Error() string {
	return e.Errno.Error()
}

func (e *ErrErrno) Temporary() bool {
	o := e.Errno
	return o == syscall.EINTR || o == syscall.EMFILE || o == syscall.ENFILE ||
		o == syscall.ENOBUFS || e.Timeout()
}

func (e *ErrErrno) Timeout() bool {
	o := e.Errno
	return o == syscall.EAGAIN || o == syscall.EWOULDBLOCK || o == syscall.ETIMEDOUT ||
		o == syscall.ETIME
}

func (e *ErrErrno) Canceled() bool {
	return e.Errno == syscall.ECANCELED
}

func (e *ErrErrno) ConnectionReset() bool {
	return e.Errno == syscall.ECONNRESET || e.Errno == syscall.ENOTCONN
}

// #region providedBuffers

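// providedBuffers is the receive buffer pool shared with the kernel as an
// io_uring provided buffer group. The kernel picks a buffer for each recv
// completion; the application returns it with release when done.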
type providedBuffers struct {
	br      *giouring.BufAndRing
	data    []byte
	entries uint32
	bufLen  uint32
}

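// init mmaps one anonymous region for all buffers and registers it with the
// kernel as buffer group buffersGroupID.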
func (b *providedBuffers) init(ring *giouring.Ring, entries uint32, bufLen uint32) error {
	b.entries = entries
	b.bufLen = bufLen
	// mmap allocated space for all buffers
	var err error
	size := int(b.entries * b.bufLen)
	b.data, err = syscall.Mmap(-1, 0, size,
		syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANON|syscall.MAP_PRIVATE)
	if err != nil {
		return err
	}
	// share buffers with io_uring
	b.br, err = ring.SetupBufRing(b.entries, buffersGroupID, 0)
	if err != nil {
		return err
	}
	for i := uint32(0); i < b.entries; i++ {
		b.br.BufRingAdd(
			uintptr(unsafe.Pointer(&b.data[b.bufLen*i])),
			b.bufLen,
			uint16(i),
			giouring.BufRingMask(b.entries),
			int(i),
		)
	}
	b.br.BufRingAdvance(int(b.entries))
	return nil
}

// get returns the provided buffer selected by the kernel for this completion.
// The buffer id is encoded in the cqe flags; the returned slice is trimmed to
// the res bytes that were received.
func (b *providedBuffers) get(res int32, flags uint32) ([]byte, uint16) {
	isProvidedBuffer := flags&giouring.CQEFBuffer > 0
	if !isProvidedBuffer {
		panic("missing buffer flag")
	}
	bufferID := uint16(flags >> giouring.CQEBufferShift)
	start := uint32(bufferID) * b.bufLen
	n := uint32(res)
	return b.data[start : start+n], bufferID
}

// return provided buffer to the kernel
func (b *providedBuffers) release(buf []byte, bufferID uint16) {
	b.br.BufRingAdd(
		uintptr(unsafe.Pointer(&buf[0])),
		b.bufLen,
		uint16(bufferID),
		giouring.BufRingMask(b.entries),
		0,
	)
	b.br.BufRingAdvance(1)
}

func (b *providedBuffers) deinit() {
	_ = syscall.Munmap(b.data)
}

// #endregion providedBuffers

// #region callbacks

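// callbacks maps sqe/cqe user data to completion callbacks. A multishot
// operation keeps its callback registered until a cqe arrives without the
// "more" flag set.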
type callbacks struct {
	m    map[uint64]completionCallback
	next uint64
}

func (c *callbacks) init() {
	c.m = make(map[uint64]completionCallback)
	c.next = math.MaxUint16 // reserve first few userdata values for internal use
}

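// set registers cb under a fresh user data key and stores that key on the sqe.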
func (c *callbacks) set(sqe *giouring.SubmissionQueueEntry, cb completionCallback) {
	c.next++
	key := c.next
	c.m[key] = cb
	sqe.UserData = key
}

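// get returns the callback registered for the cqe's user data. Single shot
// callbacks are removed; multishot callbacks stay registered while the cqe
// has the "more" flag set.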
func (c *callbacks) get(cqe *giouring.CompletionQueueEvent) completionCallback {
	ms := isMultiShot(cqe.Flags)
	cb := c.m[cqe.UserData]
	if !ms {
		delete(c.m, cqe.UserData)
	}
	return cb
}

func (c *callbacks) count() int {
	return len(c.m)
}

// #endregion callbacks

func isMultiShot(flags uint32) bool {
	return flags&giouring.CQEFMore > 0
}

// Dialed is the callback fired when a TCP connection is established. On
// failure err is set and tcpConn is nil.
type Dialed func(fd int, tcpConn *TCPConn, err error)

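// Dial asynchronously creates a socket and connects it to the TCP address
// addr, then fires dialed with the result. The returned error covers only
// address resolution; connect failures are delivered through dialed.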
func (l *Loop) Dial(addr string, dialed Dialed) error {
	sa, domain, err := resolveTCPAddr(addr)
	if err != nil {
		return err
	}
	rawAddr, rawAddrLen, err := sockaddr(sa)
	if err != nil {
		return err
	}
	// pin rawAddr for the duration of the asynchronous connect
	var pinner runtime.Pinner
	pinner.Pin(rawAddr)
	l.prepareStreamSocket(domain, func(res int32, flags uint32, err *ErrErrno) {
		if err != nil {
			dialed(0, nil, err)
			pinner.Unpin()
			return
		}
		fd := int(res)
		l.prepareConnect(fd, uintptr(rawAddr), uint64(rawAddrLen), func(res int32, flags uint32, err *ErrErrno) {
			defer pinner.Unpin()
			if err != nil {
				dialed(0, nil, err)
				return
			}
			conn := newTcpConn(l, func() { delete(l.connections, fd) }, fd)
			l.connections[fd] = conn
			dialed(fd, conn, nil)
		})
	})
	return nil
}

// Accepted is the callback fired for each new connection accepted by a
// listener.
type Accepted func(fd int, tcpConn *TCPConn)

// Listen starts a TCP listener on addr and fires accepted for every inbound
// connection. Address examples:
//
//	ip4: "127.0.0.1:8080"
//	ip6: "[::1]:80"
func (l *Loop) Listen(addr string, accepted Accepted) (*TCPListener, error) {
	sa, domain, err := resolveTCPAddr(addr)
	if err != nil {
		return nil, err
	}
	fd, port, err := listen(sa, domain)
	if err != nil {
		return nil, err
	}
	ln := &TCPListener{
		fd:          fd,
		port:        port,
		loop:        l,
		accepted:    accepted,
		connections: make(map[int]*TCPConn),
	}
	l.listeners[fd] = ln
	ln.accept()
	return ln, nil
}