github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/sys_conn_oob.go (about)

     1  //go:build darwin || linux || freebsd
     2  
     3  package quic
     4  
     5  import (
     6  	"encoding/binary"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"net/netip"
    11  	"syscall"
    12  	"time"
    13  
    14  	"golang.org/x/net/ipv4"
    15  	"golang.org/x/net/ipv6"
    16  	"golang.org/x/sys/unix"
    17  
    18  	"github.com/mikelsr/quic-go/internal/protocol"
    19  	"github.com/mikelsr/quic-go/internal/utils"
    20  )
    21  
    22  const (
    23  	ecnMask       = 0x3
    24  	oobBufferSize = 128
    25  )
    26  
    27  // Contrary to what the naming suggests, the ipv{4,6}.Message is not dependent on the IP version.
    28  // They're both just aliases for x/net/internal/socket.Message.
    29  // This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages.
    30  var _ ipv4.Message = ipv6.Message{}
    31  
    32  type batchConn interface {
    33  	ReadBatch(ms []ipv4.Message, flags int) (int, error)
    34  }
    35  
    36  func inspectReadBuffer(c syscall.RawConn) (int, error) {
    37  	var size int
    38  	var serr error
    39  	if err := c.Control(func(fd uintptr) {
    40  		size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
    41  	}); err != nil {
    42  		return 0, err
    43  	}
    44  	return size, serr
    45  }
    46  
    47  func inspectWriteBuffer(c syscall.RawConn) (int, error) {
    48  	var size int
    49  	var serr error
    50  	if err := c.Control(func(fd uintptr) {
    51  		size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
    52  	}); err != nil {
    53  		return 0, err
    54  	}
    55  	return size, serr
    56  }
    57  
// oobConn wraps an OOBCapablePacketConn. ReadPacket reads messages in batches
// and extracts per-packet control-message data (ECN bits, packet info).
// WritePacket sends with WriteMsgUDP, appending a UDP_SEGMENT message when GSO
// was enabled in newConn.
type oobConn struct {
	OOBCapablePacketConn
	// batchConn is what ReadPacket calls ReadBatch on to fill messages.
	batchConn batchConn

	// readPos is the index into messages of the next packet ReadPacket returns.
	// When readPos == len(messages), the batch is exhausted and a new one is read.
	readPos uint8
	// Packets received from the kernel, but not yet returned by ReadPacket().
	messages []ipv4.Message
	// buffers[i] is the packet buffer whose Data backs messages[i].Buffers[0].
	buffers  [batchSize]*packetBuffer

	// cap records the capabilities set in newConn (DF, GSO).
	cap connCapabilities
}
    69  
    70  var _ rawConn = &oobConn{}
    71  
    72  func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
    73  	rawConn, err := c.SyscallConn()
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  	needsPacketInfo := false
    78  	if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
    79  		needsPacketInfo = true
    80  	}
    81  	// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
    82  	// Try enabling receiving of ECN and packet info for both IP versions.
    83  	// We expect at least one of those syscalls to succeed.
    84  	var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error
    85  	if err := rawConn.Control(func(fd uintptr) {
    86  		errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
    87  		errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
    88  
    89  		if needsPacketInfo {
    90  			errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4PKTINFO, 1)
    91  			errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
    92  		}
    93  	}); err != nil {
    94  		return nil, err
    95  	}
    96  	switch {
    97  	case errECNIPv4 == nil && errECNIPv6 == nil:
    98  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
    99  	case errECNIPv4 == nil && errECNIPv6 != nil:
   100  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
   101  	case errECNIPv4 != nil && errECNIPv6 == nil:
   102  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
   103  	case errECNIPv4 != nil && errECNIPv6 != nil:
   104  		return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
   105  	}
   106  	if needsPacketInfo {
   107  		switch {
   108  		case errPIIPv4 == nil && errPIIPv6 == nil:
   109  			utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.")
   110  		case errPIIPv4 == nil && errPIIPv6 != nil:
   111  			utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.")
   112  		case errPIIPv4 != nil && errPIIPv6 == nil:
   113  			utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.")
   114  		case errPIIPv4 != nil && errPIIPv6 != nil:
   115  			return nil, errors.New("activating packet info failed for both IPv4 and IPv6")
   116  		}
   117  	}
   118  
   119  	// Allows callers to pass in a connection that already satisfies batchConn interface
   120  	// to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor
   121  	// via SyscallConn(), and read it that way, which might not be what the caller wants.
   122  	var bc batchConn
   123  	if ibc, ok := c.(batchConn); ok {
   124  		bc = ibc
   125  	} else {
   126  		bc = ipv4.NewPacketConn(c)
   127  	}
   128  
   129  	// Try enabling GSO.
   130  	// This will only succeed on Linux, and only for kernels > 4.18.
   131  	supportsGSO := maybeSetGSO(rawConn)
   132  
   133  	msgs := make([]ipv4.Message, batchSize)
   134  	for i := range msgs {
   135  		// preallocate the [][]byte
   136  		msgs[i].Buffers = make([][]byte, 1)
   137  	}
   138  	oobConn := &oobConn{
   139  		OOBCapablePacketConn: c,
   140  		batchConn:            bc,
   141  		messages:             msgs,
   142  		readPos:              batchSize,
   143  	}
   144  	oobConn.cap.DF = supportsDF
   145  	oobConn.cap.GSO = supportsGSO
   146  	for i := 0; i < batchSize; i++ {
   147  		oobConn.messages[i].OOB = make([]byte, oobBufferSize)
   148  	}
   149  	return oobConn, nil
   150  }
   151  
   152  func (c *oobConn) ReadPacket() (receivedPacket, error) {
   153  	if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
   154  		c.messages = c.messages[:batchSize]
   155  		// replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call
   156  		for i := uint8(0); i < c.readPos; i++ {
   157  			buffer := getPacketBuffer()
   158  			buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
   159  			c.buffers[i] = buffer
   160  			c.messages[i].Buffers[0] = c.buffers[i].Data
   161  		}
   162  		c.readPos = 0
   163  
   164  		n, err := c.batchConn.ReadBatch(c.messages, 0)
   165  		if n == 0 || err != nil {
   166  			return receivedPacket{}, err
   167  		}
   168  		c.messages = c.messages[:n]
   169  	}
   170  
   171  	msg := c.messages[c.readPos]
   172  	buffer := c.buffers[c.readPos]
   173  	c.readPos++
   174  
   175  	data := msg.OOB[:msg.NN]
   176  	p := receivedPacket{
   177  		remoteAddr: msg.Addr,
   178  		rcvTime:    time.Now(),
   179  		data:       msg.Buffers[0][:msg.N],
   180  		buffer:     buffer,
   181  	}
   182  	for len(data) > 0 {
   183  		hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data)
   184  		if err != nil {
   185  			return receivedPacket{}, err
   186  		}
   187  		if hdr.Level == unix.IPPROTO_IP {
   188  			switch hdr.Type {
   189  			case msgTypeIPTOS:
   190  				p.ecn = protocol.ECN(body[0] & ecnMask)
   191  			case ipv4PKTINFO:
   192  				// struct in_pktinfo {
   193  				// 	unsigned int   ipi_ifindex;  /* Interface index */
   194  				// 	struct in_addr ipi_spec_dst; /* Local address */
   195  				// 	struct in_addr ipi_addr;     /* Header Destination
   196  				// 									address */
   197  				// };
   198  				var ip [4]byte
   199  				if len(body) == 12 {
   200  					copy(ip[:], body[8:12])
   201  					p.info.ifIndex = binary.LittleEndian.Uint32(body)
   202  				} else if len(body) == 4 {
   203  					// FreeBSD
   204  					copy(ip[:], body)
   205  				}
   206  				p.info.addr = netip.AddrFrom4(ip)
   207  			}
   208  		}
   209  		if hdr.Level == unix.IPPROTO_IPV6 {
   210  			switch hdr.Type {
   211  			case unix.IPV6_TCLASS:
   212  				p.ecn = protocol.ECN(body[0] & ecnMask)
   213  			case unix.IPV6_PKTINFO:
   214  				// struct in6_pktinfo {
   215  				// 	struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
   216  				// 	unsigned int    ipi6_ifindex; /* send/recv interface index */
   217  				// };
   218  				if len(body) == 20 {
   219  					var ip [16]byte
   220  					copy(ip[:], body[:16])
   221  					p.info.addr = netip.AddrFrom16(ip)
   222  					p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
   223  				}
   224  			}
   225  		}
   226  		data = remainder
   227  	}
   228  	return p, nil
   229  }
   230  
   231  // WriteTo (re)implements the net.PacketConn method.
   232  // This is needed for users who call OptimizeConn to be able to send (non-QUIC) packets on the underlying connection.
   233  // With GSO enabled, this would otherwise not be needed, as the kernel requires the UDP_SEGMENT message to be set.
   234  func (c *oobConn) WriteTo(p []byte, addr net.Addr) (int, error) {
   235  	return c.WritePacket(p, uint16(len(p)), addr, nil)
   236  }
   237  
   238  // WritePacket writes a new packet.
   239  // If the connection supports GSO (and we activated GSO support before),
   240  // it appends the UDP_SEGMENT size message to oob.
   241  // Callers are advised to make sure that oob has a sufficient capacity,
   242  // such that appending the UDP_SEGMENT size message doesn't cause an allocation.
   243  func (c *oobConn) WritePacket(b []byte, packetSize uint16, addr net.Addr, oob []byte) (n int, err error) {
   244  	if c.cap.GSO {
   245  		oob = appendUDPSegmentSizeMsg(oob, packetSize)
   246  	} else if uint16(len(b)) != packetSize {
   247  		panic(fmt.Sprintf("inconsistent length. got: %d. expected %d", packetSize, len(b)))
   248  	}
   249  	n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
   250  	return n, err
   251  }
   252  
   253  func (c *oobConn) capabilities() connCapabilities {
   254  	return c.cap
   255  }
   256  
// packetInfo holds the addressing information that ReadPacket parses from the
// ipv4PKTINFO / IPV6_PKTINFO control messages of a received packet. OOB turns
// it back into a control message for sending.
type packetInfo struct {
	// addr is the packet's destination address as reported by the kernel.
	// It stays the zero netip.Addr if no packet info control message was seen.
	addr    netip.Addr
	// ifIndex is the interface index reported alongside addr.
	ifIndex uint32
}
   261  
   262  func (info *packetInfo) OOB() []byte {
   263  	if info == nil {
   264  		return nil
   265  	}
   266  	if info.addr.Is4() {
   267  		ip := info.addr.As4()
   268  		// struct in_pktinfo {
   269  		// 	unsigned int   ipi_ifindex;  /* Interface index */
   270  		// 	struct in_addr ipi_spec_dst; /* Local address */
   271  		// 	struct in_addr ipi_addr;     /* Header Destination address */
   272  		// };
   273  		cm := ipv4.ControlMessage{
   274  			Src:     ip[:],
   275  			IfIndex: int(info.ifIndex),
   276  		}
   277  		return cm.Marshal()
   278  	} else if info.addr.Is6() {
   279  		ip := info.addr.As16()
   280  		// struct in6_pktinfo {
   281  		// 	struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
   282  		// 	unsigned int    ipi6_ifindex; /* send/recv interface index */
   283  		// };
   284  		cm := ipv6.ControlMessage{
   285  			Src:     ip[:],
   286  			IfIndex: int(info.ifIndex),
   287  		}
   288  		return cm.Marshal()
   289  	}
   290  	return nil
   291  }