github.com/ianic/xnet/aio@v0.0.0-20230924160527-cee7f41ab201/loop.go (about) 1 package aio 2 3 import ( 4 "context" 5 "log/slog" 6 "math" 7 "os" 8 "runtime" 9 "syscall" 10 "time" 11 "unsafe" 12 13 "github.com/pawelgaczynski/giouring" 14 ) 15 16 const ( 17 batchSize = 128 18 buffersGroupID = 0 // currently using only 1 provided buffer group 19 ) 20 21 type completionCallback = func(res int32, flags uint32, err *ErrErrno) 22 type operation = func(*giouring.SubmissionQueueEntry) 23 24 type Loop struct { 25 ring *giouring.Ring 26 callbacks callbacks 27 buffers providedBuffers 28 pending []operation 29 30 listeners map[int]*TCPListener 31 connections map[int]*TCPConn 32 } 33 34 type Options struct { 35 RingEntries uint32 36 RecvBuffersCount uint32 37 RecvBufferLen uint32 38 } 39 40 var DefaultOptions = Options{ 41 RingEntries: 1024, 42 RecvBuffersCount: 256, 43 RecvBufferLen: 4 * 1024, 44 } 45 46 func New(opt Options) (*Loop, error) { 47 ring, err := giouring.CreateRing(opt.RingEntries) 48 if err != nil { 49 return nil, err 50 } 51 l := &Loop{ 52 ring: ring, 53 listeners: make(map[int]*TCPListener), 54 connections: make(map[int]*TCPConn), 55 } 56 l.callbacks.init() 57 if err := l.buffers.init(ring, opt.RecvBuffersCount, opt.RecvBufferLen); err != nil { 58 return nil, err 59 } 60 return l, nil 61 } 62 63 // runOnce performs one loop run. 64 // Submits all prepared operations to the kernel and waits for at least one 65 // completed operation by the kernel. 66 func (l *Loop) runOnce() error { 67 if err := l.submitAndWait(1); err != nil { 68 return err 69 } 70 _ = l.flushCompletions() 71 return nil 72 } 73 74 // runUntilDone runs loop until all prepared operations are finished. 75 func (l *Loop) runUntilDone() error { 76 for { 77 if l.callbacks.count() == 0 { 78 if len(l.connections) > 0 || len(l.listeners) > 0 { 79 panic("unclean shutdown") 80 } 81 return nil 82 } 83 if err := l.runOnce(); err != nil { 84 return err 85 } 86 } 87 } 88 89 // Run runs loop until ctx is cancelled. Then performs clean shutdown. 90 // After ctx is done it closes all pending listeners and dialed connections. 91 // Listener will first stop listening then close all accepted connections. 92 // Loop will wait for all operations to finish. 93 func (l *Loop) Run(ctx context.Context) error { 94 // run until ctx is done 95 if err := l.runCtx(ctx, time.Millisecond*333); err != nil { 96 return err 97 } 98 l.closePendingConnections() 99 // run loop until all operations finishes 100 if err := l.runUntilDone(); err != nil { 101 return err 102 } 103 return nil 104 } 105 106 func (l *Loop) closePendingConnections() { 107 for _, lsn := range l.listeners { 108 lsn.Close() 109 } 110 for _, conn := range l.connections { 111 conn.Close() 112 } 113 } 114 115 // runCtx runs loop until context is canceled. 116 // Checks context every `timeout`. 117 func (l *Loop) runCtx(ctx context.Context, timeout time.Duration) error { 118 ts := syscall.NsecToTimespec(int64(timeout)) 119 done := func() bool { 120 select { 121 case <-ctx.Done(): 122 return true 123 default: 124 } 125 return false 126 } 127 for { 128 if err := l.submit(); err != nil { 129 return err 130 } 131 if _, err := l.ring.WaitCQEs(1, &ts, nil); err != nil && !TemporaryError(err) { 132 return err 133 } 134 _ = l.flushCompletions() 135 if done() { 136 break 137 } 138 } 139 return nil 140 } 141 142 // TemporaryError returns true if syscall.Errno should be threated as temporary. 143 func TemporaryError(err error) bool { 144 if errno, ok := err.(syscall.Errno); ok { 145 return (&ErrErrno{Errno: errno}).Temporary() 146 } 147 if os.IsTimeout(err) { 148 return true 149 } 150 return false 151 } 152 153 // Retries on temporary errors. 154 // Anything not handled here is fatal and application should terminate. 155 // Errors that can be returned by [io_uring_enter]. 156 // 157 // [io_uring_enter]: https://manpages.debian.org/unstable/liburing-dev/io_uring_enter.2.en.html#ERRORS 158 func (l *Loop) submitAndWait(waitNr uint32) error { 159 for { 160 if len(l.pending) > 0 { 161 _, err := l.ring.SubmitAndWait(0) 162 if err == nil { 163 l.preparePending() 164 } 165 } 166 167 _, err := l.ring.SubmitAndWait(waitNr) 168 if err != nil && TemporaryError(err) { 169 continue 170 } 171 return err 172 } 173 } 174 175 func (l *Loop) preparePending() { 176 prepared := 0 177 for _, op := range l.pending { 178 sqe := l.ring.GetSQE() 179 if sqe == nil { 180 break 181 } 182 op(sqe) 183 prepared++ 184 } 185 if prepared == len(l.pending) { 186 l.pending = nil 187 } else { 188 l.pending = l.pending[prepared:] 189 } 190 } 191 192 func (l *Loop) submit() error { 193 return l.submitAndWait(0) 194 } 195 196 func (l *Loop) flushCompletions() uint32 { 197 var cqes [batchSize]*giouring.CompletionQueueEvent 198 var noCompleted uint32 = 0 199 for { 200 peeked := l.ring.PeekBatchCQE(cqes[:]) 201 for _, cqe := range cqes[:peeked] { 202 err := cqeErr(cqe) 203 if cqe.UserData == 0 { 204 slog.Debug("ceq without userdata", "res", cqe.Res, "flags", cqe.Flags, "err", err) 205 continue 206 } 207 cb := l.callbacks.get(cqe) 208 cb(cqe.Res, cqe.Flags, err) 209 } 210 l.ring.CQAdvance(peeked) 211 noCompleted += peeked 212 if peeked < uint32(len(cqes)) { 213 return noCompleted 214 } 215 } 216 } 217 218 func (l *Loop) Close() { 219 l.ring.QueueExit() 220 l.buffers.deinit() 221 } 222 223 // prepares operation or adds it to pending if can't get sqe 224 func (l *Loop) prepare(op operation) { 225 sqe := l.ring.GetSQE() 226 if sqe == nil { // submit and retry 227 l.submit() 228 sqe = l.ring.GetSQE() 229 } 230 if sqe == nil { // still nothing, add to pending 231 l.pending = append(l.pending, op) 232 return 233 } 234 op(sqe) 235 } 236 237 func (l *Loop) prepareMultishotAccept(fd int, cb completionCallback) { 238 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 239 sqe.PrepareMultishotAccept(fd, 0, 0, 0) 240 l.callbacks.set(sqe, cb) 241 }) 242 } 243 244 func (l *Loop) prepareCancelFd(fd int, cb completionCallback) { 245 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 246 sqe.PrepareCancelFd(fd, 0) 247 l.callbacks.set(sqe, cb) 248 }) 249 } 250 251 func (l *Loop) prepareShutdown(fd int, cb completionCallback) { 252 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 253 const SHUT_RDWR = 2 254 sqe.PrepareShutdown(fd, SHUT_RDWR) 255 l.callbacks.set(sqe, cb) 256 }) 257 } 258 259 func (l *Loop) prepareClose(fd int, cb completionCallback) { 260 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 261 sqe.PrepareClose(fd) 262 l.callbacks.set(sqe, cb) 263 }) 264 } 265 266 // assumes that buf is already pinned in the caller 267 func (l *Loop) prepareSend(fd int, buf []byte, cb completionCallback) { 268 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 269 sqe.PrepareSend(fd, uintptr(unsafe.Pointer(&buf[0])), uint32(len(buf)), 0) 270 l.callbacks.set(sqe, cb) 271 }) 272 } 273 274 // references from std lib: 275 // https://github.com/golang/go/blob/140266fe7521bf75bf0037f12265190213cc8e7d/src/internal/poll/writev.go#L16 276 // https://github.com/golang/go/blob/140266fe7521bf75bf0037f12265190213cc8e7d/src/internal/poll/fd_writev_unix.go#L20 277 // assumes that iovecs are pinner in caller 278 func (l *Loop) prepareWritev(fd int, iovecs []syscall.Iovec, cb completionCallback) { 279 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 280 sqe.PrepareWritev(fd, uintptr(unsafe.Pointer(&iovecs[0])), uint32(len(iovecs)), 0) 281 l.callbacks.set(sqe, cb) 282 }) 283 } 284 285 // Multishot, provided buffers recv 286 func (l *Loop) prepareRecv(fd int, cb completionCallback) { 287 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 288 sqe.PrepareRecvMultishot(fd, 0, 0, 0) 289 sqe.Flags = giouring.SqeBufferSelect 290 sqe.BufIG = buffersGroupID 291 l.callbacks.set(sqe, cb) 292 }) 293 } 294 295 func (l *Loop) prepareConnect(fd int, addr uintptr, addrLen uint64, cb completionCallback) { 296 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 297 sqe.PrepareConnect(fd, addr, addrLen) 298 l.callbacks.set(sqe, cb) 299 }) 300 } 301 302 func (l *Loop) prepareStreamSocket(domain int, cb completionCallback) { 303 l.prepare(func(sqe *giouring.SubmissionQueueEntry) { 304 sqe.PrepareSocket(domain, syscall.SOCK_STREAM, 0, 0) 305 l.callbacks.set(sqe, cb) 306 }) 307 } 308 309 func cqeErr(c *giouring.CompletionQueueEvent) *ErrErrno { 310 if c.Res > -4096 && c.Res < 0 { 311 errno := syscall.Errno(-c.Res) 312 return &ErrErrno{Errno: errno} 313 } 314 return nil 315 } 316 317 type ErrErrno struct { 318 Errno syscall.Errno 319 } 320 321 func (e *ErrErrno) Error() string { 322 return e.Errno.Error() 323 } 324 325 func (e *ErrErrno) Temporary() bool { 326 o := e.Errno 327 return o == syscall.EINTR || o == syscall.EMFILE || o == syscall.ENFILE || 328 o == syscall.ENOBUFS || e.Timeout() 329 } 330 331 func (e *ErrErrno) Timeout() bool { 332 o := e.Errno 333 return o == syscall.EAGAIN || o == syscall.EWOULDBLOCK || o == syscall.ETIMEDOUT || 334 o == syscall.ETIME 335 } 336 337 func (e *ErrErrno) Canceled() bool { 338 return e.Errno == syscall.ECANCELED 339 } 340 341 func (e *ErrErrno) ConnectionReset() bool { 342 return e.Errno == syscall.ECONNRESET || e.Errno == syscall.ENOTCONN 343 } 344 345 // #region providedBuffers 346 347 type providedBuffers struct { 348 br *giouring.BufAndRing 349 data []byte 350 entries uint32 351 bufLen uint32 352 } 353 354 func (b *providedBuffers) init(ring *giouring.Ring, entries uint32, bufLen uint32) error { 355 b.entries = entries 356 b.bufLen = bufLen 357 // mmap allocated space for all buffers 358 var err error 359 size := int(b.entries * b.bufLen) 360 b.data, err = syscall.Mmap(-1, 0, size, 361 syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANON|syscall.MAP_PRIVATE) 362 if err != nil { 363 return err 364 } 365 // share buffers with io_uring 366 b.br, err = ring.SetupBufRing(b.entries, buffersGroupID, 0) 367 if err != nil { 368 return err 369 } 370 for i := uint32(0); i < b.entries; i++ { 371 b.br.BufRingAdd( 372 uintptr(unsafe.Pointer(&b.data[b.bufLen*i])), 373 b.bufLen, 374 uint16(i), 375 giouring.BufRingMask(b.entries), 376 int(i), 377 ) 378 } 379 b.br.BufRingAdvance(int(b.entries)) 380 return nil 381 } 382 383 // get provided buffer from cqe res, flags 384 func (b *providedBuffers) get(res int32, flags uint32) ([]byte, uint16) { 385 isProvidedBuffer := flags&giouring.CQEFBuffer > 0 386 if !isProvidedBuffer { 387 panic("missing buffer flag") 388 } 389 bufferID := uint16(flags >> giouring.CQEBufferShift) 390 start := uint32(bufferID) * b.bufLen 391 n := uint32(res) 392 return b.data[start : start+n], bufferID 393 } 394 395 // return provided buffer to the kernel 396 func (b *providedBuffers) release(buf []byte, bufferID uint16) { 397 b.br.BufRingAdd( 398 uintptr(unsafe.Pointer(&buf[0])), 399 b.bufLen, 400 uint16(bufferID), 401 giouring.BufRingMask(b.entries), 402 0, 403 ) 404 b.br.BufRingAdvance(1) 405 } 406 407 func (b *providedBuffers) deinit() { 408 _ = syscall.Munmap(b.data) 409 } 410 411 //#endregion providedBuffers 412 413 // #region callbacks 414 415 type callbacks struct { 416 m map[uint64]completionCallback 417 next uint64 418 } 419 420 func (c *callbacks) init() { 421 c.m = make(map[uint64]completionCallback) 422 c.next = math.MaxUint16 // reserve first few userdata values for internal use 423 } 424 425 func (c *callbacks) set(sqe *giouring.SubmissionQueueEntry, cb completionCallback) { 426 c.next++ 427 key := c.next 428 c.m[key] = cb 429 sqe.UserData = key 430 } 431 432 func (c *callbacks) get(cqe *giouring.CompletionQueueEvent) completionCallback { 433 ms := isMultiShot(cqe.Flags) 434 cb := c.m[cqe.UserData] 435 if !ms { 436 delete(c.m, cqe.UserData) 437 } 438 return cb 439 } 440 441 func (c *callbacks) count() int { 442 return len(c.m) 443 } 444 445 // #endregion 446 447 func isMultiShot(flags uint32) bool { 448 return flags&giouring.CQEFMore > 0 449 } 450 451 // callback fired when tcp connection is dialed 452 type Dialed func(fd int, tcpConn *TCPConn, err error) 453 454 func (l *Loop) Dial(addr string, dialed Dialed) error { 455 sa, domain, err := resolveTCPAddr(addr) 456 if err != nil { 457 return err 458 } 459 rawAddr, rawAddrLen, err := sockaddr(sa) 460 if err != nil { 461 return err 462 } 463 var pinner runtime.Pinner 464 pinner.Pin(rawAddr) 465 l.prepareStreamSocket(domain, func(res int32, flags uint32, err *ErrErrno) { 466 if err != nil { 467 dialed(0, nil, err) 468 pinner.Unpin() 469 return 470 } 471 fd := int(res) 472 l.prepareConnect(fd, uintptr(rawAddr), uint64(rawAddrLen), func(res int32, flags uint32, err *ErrErrno) { 473 defer pinner.Unpin() 474 if err != nil { 475 dialed(0, nil, err) 476 return 477 } 478 conn := newTcpConn(l, func() { delete(l.connections, fd) }, fd) 479 l.connections[fd] = conn 480 dialed(fd, conn, nil) 481 }) 482 }) 483 return nil 484 } 485 486 // callback fired when new connection is accepted by listener 487 type Accepted func(fd int, tcpConn *TCPConn) 488 489 // ip4: "127.0.0.1:8080", 490 // ip6: "[::1]:80" 491 func (l *Loop) Listen(addr string, accepted Accepted) (*TCPListener, error) { 492 sa, domain, err := resolveTCPAddr(addr) 493 if err != nil { 494 return nil, err 495 } 496 fd, port, err := listen(sa, domain) 497 if err != nil { 498 return nil, err 499 } 500 ln := &TCPListener{ 501 fd: fd, 502 port: port, 503 loop: l, 504 accepted: accepted, 505 connections: make(map[int]*TCPConn), 506 } 507 l.listeners[fd] = ln 508 ln.accept() 509 return ln, nil 510 }