github.com/ianic/xnet/aio@v0.0.0-20230924160527-cee7f41ab201/tcp_conn.go (about)

     1  package aio
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"log/slog"
     7  	"runtime"
     8  	"syscall"
     9  )
    10  
    // ErrListenerClose is the reason reported when the listener closed the
    // connection.
    var ErrListenerClose = errors.New("listener closed connection")

    // ErrUpstreamClose is the reason reported when the upstream handler closed
    // the connection by calling Close.
    var ErrUpstreamClose = errors.New("upstream closed connection")
    15  
    // Upstream is the event-handler interface implemented by the layer above a
    // TCPConn (for example a websocket layer). The connection reports data and
    // lifecycle events through these methods.
    type Upstream interface {
    	// Received is called with each chunk of data read from the connection.
    	// The buffer is released back to the loop as soon as Received returns,
    	// so implementations must copy any bytes they want to keep.
    	Received([]byte)
    	// Sent is called once a Send or SendBuffers payload has been written
    	// completely.
    	Sent()
    	// Closed is called after the connection has been torn down, with the
    	// reason that started the teardown (io.EOF, ErrUpstreamClose, or an
    	// I/O error).
    	Closed(error)
    }
    22  
    23  type TCPConn struct {
    24  	closedCallback func()
    25  	loop           *Loop
    26  	fd             int
    27  	up             Upstream
    28  	shutdownError  error
    29  }
    30  
    31  func newTcpConn(loop *Loop, closedCallback func(), fd int) *TCPConn {
    32  	return &TCPConn{loop: loop, fd: fd, closedCallback: closedCallback}
    33  }
    34  
    35  // Bind connects this connection and upstream handler. It's up to the
    36  // upstream handler to call bind when ready. On in any other time when it need
    37  // to change upstream layer. For example during after websocket handshake layer
    38  // can be change from one which were handling handshake to one which will handle
    39  // websocket frames.
    40  func (tc *TCPConn) Bind(up Upstream) {
    41  	startRecv := tc.up == nil
    42  	tc.up = up
    43  	if startRecv {
    44  		tc.recvLoop()
    45  	}
    46  }
    47  
    48  // TODO: add correlation id (userdata) for send/sent connecting
    49  func (tc *TCPConn) Send(data []byte) {
    50  	nn := 0 // number of bytes sent
    51  	var cb completionCallback
    52  	var pinner runtime.Pinner
    53  	pinner.Pin(&data[0])
    54  	cb = func(res int32, flags uint32, err *ErrErrno) {
    55  		nn += int(res) // bytes written so far
    56  		if err != nil {
    57  			pinner.Unpin()
    58  			tc.shutdown(err)
    59  			return
    60  		}
    61  		if nn >= len(data) {
    62  			pinner.Unpin()
    63  			tc.up.Sent() // all sent call callback
    64  			return
    65  		}
    66  		// send rest of the data
    67  		tc.loop.prepareSend(tc.fd, data[nn:], cb)
    68  	}
    69  	tc.loop.prepareSend(tc.fd, data, cb)
    70  }
    71  
    72  func (tc *TCPConn) SendBuffers(buffers [][]byte) {
    73  	var cb completionCallback
    74  	var pinner runtime.Pinner
    75  	for _, buf := range buffers {
    76  		pinner.Pin(&buf[0])
    77  	}
    78  	cb = func(res int32, flags uint32, err *ErrErrno) {
    79  		n := int(res)
    80  		if err != nil {
    81  			pinner.Unpin()
    82  			tc.shutdown(err)
    83  			return
    84  		}
    85  		consumeBuffers(&buffers, n)
    86  		if len(buffers) == 0 {
    87  			pinner.Unpin()
    88  			tc.up.Sent()
    89  			return
    90  		}
    91  		// send rest of the data
    92  		iovecs := buffersToIovec(buffers)
    93  		pinner.Pin(&iovecs[0])
    94  		tc.loop.prepareWritev(tc.fd, iovecs, cb)
    95  	}
    96  	iovecs := buffersToIovec(buffers)
    97  	pinner.Pin(&iovecs[0])
    98  	tc.loop.prepareWritev(tc.fd, iovecs, cb)
    99  }
   100  
   // buffersToIovec builds an iovec for each non-empty buffer, pointing at the
   // buffer's first byte and carrying its length. Empty buffers are skipped. The
   // result is nil when no buffer has data.
   func buffersToIovec(buffers [][]byte) []syscall.Iovec {
   	var out []syscall.Iovec
   	for i := range buffers {
   		b := buffers[i]
   		if len(b) > 0 {
   			var iov syscall.Iovec
   			iov.Base = &b[0]
   			iov.SetLen(len(b))
   			out = append(out, iov)
   		}
   	}
   	return out
   }
   112  
   // consumeBuffers drops the first n bytes from *v after a writev of n bytes.
   // Fully written buffers are removed and their slots in the backing array are
   // set to nil, so the data can be collected. A partially written buffer is
   // resliced to its unwritten tail.
   // copied from:
   // https://github.com/golang/go/blob/140266fe7521bf75bf0037f12265190213cc8e7d/src/internal/poll/fd.go#L69
   func consumeBuffers(v *[][]byte, n int) {
   	bufs := *v
   	for len(bufs) > 0 {
   		head := bufs[0]
   		if len(head) > n {
   			bufs[0] = head[n:]
   			break
   		}
   		n -= len(head)
   		bufs[0] = nil
   		bufs = bufs[1:]
   	}
   	*v = bufs
   }
   128  
   129  func (tc *TCPConn) Close() {
   130  	tc.shutdown(ErrUpstreamClose)
   131  }
   132  
   133  // recvLoop starts multishot recv on fd
   134  // Will receive on fd until error occurs.
   135  func (tc *TCPConn) recvLoop() {
   136  	var cb completionCallback
   137  	cb = func(res int32, flags uint32, err *ErrErrno) {
   138  		if err != nil {
   139  			if err.Temporary() {
   140  				slog.Debug("tcp conn read temporary error", "error", err.Error())
   141  				tc.loop.prepareRecv(tc.fd, cb)
   142  				return
   143  			}
   144  			if !err.ConnectionReset() {
   145  				slog.Warn("tcp conn read error", "error", err.Error())
   146  			}
   147  			tc.shutdown(err)
   148  			return
   149  		}
   150  		if res == 0 {
   151  			tc.shutdown(io.EOF)
   152  			return
   153  		}
   154  		buf, id := tc.loop.buffers.get(res, flags)
   155  		tc.up.Received(buf)
   156  		tc.loop.buffers.release(buf, id)
   157  		if !isMultiShot(flags) {
   158  			slog.Debug("tcp conn multishot terminated", slog.Uint64("flags", uint64(flags)), slog.String("error", err.Error()))
   159  			// io_uring can terminate multishot recv when cqe is full
   160  			// need to restart it then
   161  			// ref: https://lore.kernel.org/lkml/20220630091231.1456789-3-dylany@fb.com/T/#re5daa4d5b6e4390ecf024315d9693e5d18d61f10
   162  			tc.loop.prepareRecv(tc.fd, cb)
   163  		}
   164  	}
   165  	tc.loop.prepareRecv(tc.fd, cb)
   166  }
   167  
   168  // shutdown tcp (both) then close fd
   169  func (tc *TCPConn) shutdown(err error) {
   170  	if err == nil {
   171  		panic("tcp conn missing shutdown reason")
   172  	}
   173  	if tc.shutdownError != nil {
   174  		return
   175  	}
   176  	tc.shutdownError = err
   177  	tc.loop.prepareShutdown(tc.fd, func(res int32, flags uint32, err *ErrErrno) {
   178  		if err != nil {
   179  			if !err.ConnectionReset() {
   180  				slog.Debug("tcp conn shutdown", "fd", tc.fd, "err", err, "res", res, "flags", flags)
   181  			}
   182  			if tc.closedCallback != nil {
   183  				tc.closedCallback()
   184  			}
   185  			tc.up.Closed(tc.shutdownError)
   186  			return
   187  		}
   188  		tc.loop.prepareClose(tc.fd, func(res int32, flags uint32, err *ErrErrno) {
   189  			if err != nil {
   190  				slog.Debug("tcp conn close", "fd", tc.fd, "errno", err, "res", res, "flags", flags)
   191  			}
   192  			if tc.closedCallback != nil {
   193  				tc.closedCallback()
   194  			}
   195  			tc.up.Closed(tc.shutdownError)
   196  		})
   197  	})
   198  }