github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/conn/mmsg.go (about)

     1  //go:build linux || netbsd
     2  
     3  package conn
     4  
     5  import (
     6  	"net"
     7  	"os"
     8  	"syscall"
     9  )
    10  
// rawUDPConn pairs a [net.UDPConn] with the [syscall.RawConn] obtained from its
// SyscallConn method. The batch readers and writers built on top of it use the
// raw connection to issue recvmmsg(2)/sendmmsg(2) on the socket's descriptor.
type rawUDPConn struct {
	// Embedded so that the regular net.UDPConn methods remain callable.
	*net.UDPConn
	// Raw access to the socket's file descriptor (see NewRawUDPConn).
	rawConn syscall.RawConn
}
    15  
    16  // NewRawUDPConn wraps a [net.UDPConn] in a [rawUDPConn] for batch I/O.
    17  func NewRawUDPConn(udpConn *net.UDPConn) (rawUDPConn, error) {
    18  	rawConn, err := udpConn.SyscallConn()
    19  	if err != nil {
    20  		return rawUDPConn{}, err
    21  	}
    22  
    23  	return rawUDPConn{
    24  		UDPConn: udpConn,
    25  		rawConn: rawConn,
    26  	}, nil
    27  }
    28  
// MmsgRConn wraps a [net.UDPConn] and provides the [ReadMsgs] method
// for reading multiple messages in a single recvmmsg(2) system call.
//
// [MmsgRConn] is not safe for concurrent use.
// Use the [RConn] method to create a new [MmsgRConn] instance for each goroutine.
type MmsgRConn struct {
	rawUDPConn
	// rawReadFunc is the callback handed to rawConn.Read. It is built once in
	// RConn and communicates through the read* fields below (presumably so that
	// ReadMsgs does not allocate a fresh closure on every call).
	rawReadFunc func(fd uintptr) (done bool)
	// Arguments of the ReadMsgs call currently in progress.
	readMsgvec []Mmsghdr
	readFlags  int
	// Outcome of the most recent recvmmsg(2) attempt made by rawReadFunc.
	readN   int
	readErr error
}
    42  
// MmsgWConn wraps a [net.UDPConn] and provides the [WriteMsgs] method
// for writing multiple messages in a single sendmmsg(2) system call.
//
// [MmsgWConn] is not safe for concurrent use.
// Use the [WConn] method to create a new [MmsgWConn] instance for each goroutine.
type MmsgWConn struct {
	rawUDPConn
	// rawWriteFunc is the callback handed to rawConn.Write. It is built once in
	// WConn and communicates through the write* fields below.
	rawWriteFunc func(fd uintptr) (done bool)
	// Messages still waiting to be sent; rawWriteFunc advances this slice past
	// every message it has handled.
	writeMsgvec []Mmsghdr
	// Flags passed to every sendmmsg(2) call of the current WriteMsgs call.
	writeFlags int
	// Last sendmmsg(2) error recorded during the current WriteMsgs call.
	writeErr error
}
    55  
    56  // RConn returns a new [MmsgRConn] instance for batch reading.
    57  func (c rawUDPConn) RConn() *MmsgRConn {
    58  	mmsgRConn := MmsgRConn{
    59  		rawUDPConn: c,
    60  	}
    61  
    62  	mmsgRConn.rawReadFunc = func(fd uintptr) (done bool) {
    63  		var errno syscall.Errno
    64  		mmsgRConn.readN, errno = recvmmsg(int(fd), mmsgRConn.readMsgvec, mmsgRConn.readFlags)
    65  		switch errno {
    66  		case 0:
    67  		case syscall.EAGAIN:
    68  			return false
    69  		default:
    70  			mmsgRConn.readErr = os.NewSyscallError("recvmmsg", errno)
    71  		}
    72  		return true
    73  	}
    74  
    75  	return &mmsgRConn
    76  }
    77  
    78  // WConn returns a new [MmsgWConn] instance for batch writing.
    79  func (c rawUDPConn) WConn() *MmsgWConn {
    80  	mmsgWConn := MmsgWConn{
    81  		rawUDPConn: c,
    82  	}
    83  
    84  	mmsgWConn.rawWriteFunc = func(fd uintptr) (done bool) {
    85  		n, errno := sendmmsg(int(fd), mmsgWConn.writeMsgvec, mmsgWConn.writeFlags)
    86  		switch errno {
    87  		case 0:
    88  		case syscall.EAGAIN:
    89  			return false
    90  		default:
    91  			mmsgWConn.writeErr = os.NewSyscallError("sendmmsg", errno)
    92  			n = 1
    93  		}
    94  		mmsgWConn.writeMsgvec = mmsgWConn.writeMsgvec[n:]
    95  		// According to tokio, not writing the full msgvec is sufficient to show
    96  		// that the socket buffer is full. Previous tests also showed that this is
    97  		// faster than immediately trying to write again.
    98  		//
    99  		// Do keep in mind that this is not how the Go runtime handles writes though.
   100  		return len(mmsgWConn.writeMsgvec) == 0
   101  	}
   102  
   103  	return &mmsgWConn
   104  }
   105  
   106  // ReadMsgs reads as many messages as possible into the given msgvec
   107  // and returns the number of messages read or an error.
   108  func (c *MmsgRConn) ReadMsgs(msgvec []Mmsghdr, flags int) (int, error) {
   109  	c.readMsgvec = msgvec
   110  	c.readFlags = flags
   111  	c.readN = 0
   112  	c.readErr = nil
   113  	if err := c.rawConn.Read(c.rawReadFunc); err != nil {
   114  		return 0, err
   115  	}
   116  	return c.readN, c.readErr
   117  }
   118  
   119  // WriteMsgs writes all messages in the given msgvec and returns the last encountered error.
   120  func (c *MmsgWConn) WriteMsgs(msgvec []Mmsghdr, flags int) error {
   121  	c.writeMsgvec = msgvec
   122  	c.writeFlags = flags
   123  	c.writeErr = nil
   124  	if err := c.rawConn.Write(c.rawWriteFunc); err != nil {
   125  		return err
   126  	}
   127  	return c.writeErr
   128  }