github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/sys_conn_oob.go (about)

     1  //go:build darwin || linux || freebsd
     2  
     3  package quic
     4  
     5  import (
     6  	"encoding/binary"
     7  	"errors"
     8  	"log"
     9  	"net"
    10  	"net/netip"
    11  	"os"
    12  	"strconv"
    13  	"sync"
    14  	"syscall"
    15  	"time"
    16  	"unsafe"
    17  
    18  	"golang.org/x/net/ipv4"
    19  	"golang.org/x/net/ipv6"
    20  	"golang.org/x/sys/unix"
    21  
    22  	"github.com/apernet/quic-go/internal/protocol"
    23  	"github.com/apernet/quic-go/internal/utils"
    24  )
    25  
const (
	ecnMask       = 0x3 // selects the two low-order (ECN) bits of the received TOS / Traffic Class byte
	oobBufferSize = 128 // bytes of control-message space allocated per message in newConn
)
    30  
// Contrary to what the naming suggests, the ipv{4,6}.Message is not dependent on the IP version.
// They're both just aliases for x/net/internal/socket.Message.
// This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages.
// The assignment below is a compile-time check of that assumption: it stops
// compiling if the two message types ever diverge.
var _ ipv4.Message = ipv6.Message{}
    35  
// batchConn is implemented by connections that can receive several datagrams
// in a single call. newConn uses the passed-in connection directly when it
// implements this interface, and otherwise wraps it with ipv4.NewPacketConn.
type batchConn interface {
	// ReadBatch fills entries of ms with received messages and returns how many
	// were filled (ReadPacket slices ms to that count). ReadPacket passes flags = 0.
	ReadBatch(ms []ipv4.Message, flags int) (int, error)
}
    39  
    40  func inspectReadBuffer(c syscall.RawConn) (int, error) {
    41  	var size int
    42  	var serr error
    43  	if err := c.Control(func(fd uintptr) {
    44  		size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
    45  	}); err != nil {
    46  		return 0, err
    47  	}
    48  	return size, serr
    49  }
    50  
    51  func inspectWriteBuffer(c syscall.RawConn) (int, error) {
    52  	var size int
    53  	var serr error
    54  	if err := c.Control(func(fd uintptr) {
    55  		size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF)
    56  	}); err != nil {
    57  		return 0, err
    58  	}
    59  	return size, serr
    60  }
    61  
    62  func isECNDisabledUsingEnv() bool {
    63  	disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN"))
    64  	return err == nil && disabled
    65  }
    66  
// oobConn is a rawConn implementation (built on darwin, linux and freebsd) that
// reads datagrams in batches through batchConn and exchanges ancillary data
// (ECN bits, packet info, GSO segment size) through each message's OOB buffer.
type oobConn struct {
	OOBCapablePacketConn
	batchConn batchConn

	// readPos is the index in messages of the next packet ReadPacket returns;
	// once it equals len(messages), ReadPacket reads a new batch.
	readPos uint8
	// Packets received from the kernel, but not yet returned by ReadPacket().
	messages []ipv4.Message
	buffers  [batchSize]*packetBuffer // buffers[i].Data backs messages[i].Buffers[0]

	// cap records the features (DF, GSO, ECN) determined in newConn.
	cap connCapabilities
}

// Compile-time check that *oobConn implements rawConn.
var _ rawConn = &oobConn{}
    80  
    81  func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
    82  	rawConn, err := c.SyscallConn()
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	needsPacketInfo := false
    87  	if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
    88  		needsPacketInfo = true
    89  	}
    90  	// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
    91  	// Try enabling receiving of ECN and packet info for both IP versions.
    92  	// We expect at least one of those syscalls to succeed.
    93  	var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error
    94  	if err := rawConn.Control(func(fd uintptr) {
    95  		errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
    96  		errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
    97  
    98  		if needsPacketInfo {
    99  			errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4PKTINFO, 1)
   100  			errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
   101  		}
   102  	}); err != nil {
   103  		return nil, err
   104  	}
   105  	switch {
   106  	case errECNIPv4 == nil && errECNIPv6 == nil:
   107  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
   108  	case errECNIPv4 == nil && errECNIPv6 != nil:
   109  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
   110  	case errECNIPv4 != nil && errECNIPv6 == nil:
   111  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
   112  	case errECNIPv4 != nil && errECNIPv6 != nil:
   113  		return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
   114  	}
   115  	if needsPacketInfo {
   116  		switch {
   117  		case errPIIPv4 == nil && errPIIPv6 == nil:
   118  			utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.")
   119  		case errPIIPv4 == nil && errPIIPv6 != nil:
   120  			utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.")
   121  		case errPIIPv4 != nil && errPIIPv6 == nil:
   122  			utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.")
   123  		case errPIIPv4 != nil && errPIIPv6 != nil:
   124  			return nil, errors.New("activating packet info failed for both IPv4 and IPv6")
   125  		}
   126  	}
   127  
   128  	// Allows callers to pass in a connection that already satisfies batchConn interface
   129  	// to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor
   130  	// via SyscallConn(), and read it that way, which might not be what the caller wants.
   131  	var bc batchConn
   132  	if ibc, ok := c.(batchConn); ok {
   133  		bc = ibc
   134  	} else {
   135  		bc = ipv4.NewPacketConn(c)
   136  	}
   137  
   138  	msgs := make([]ipv4.Message, batchSize)
   139  	for i := range msgs {
   140  		// preallocate the [][]byte
   141  		msgs[i].Buffers = make([][]byte, 1)
   142  	}
   143  	oobConn := &oobConn{
   144  		OOBCapablePacketConn: c,
   145  		batchConn:            bc,
   146  		messages:             msgs,
   147  		readPos:              batchSize,
   148  		cap: connCapabilities{
   149  			DF:  supportsDF,
   150  			GSO: isGSOEnabled(rawConn),
   151  			ECN: isECNEnabled(),
   152  		},
   153  	}
   154  	for i := 0; i < batchSize; i++ {
   155  		oobConn.messages[i].OOB = make([]byte, oobBufferSize)
   156  	}
   157  	return oobConn, nil
   158  }
   159  
   160  var invalidCmsgOnceV4, invalidCmsgOnceV6 sync.Once
   161  
   162  func (c *oobConn) ReadPacket() (receivedPacket, error) {
   163  	if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
   164  		c.messages = c.messages[:batchSize]
   165  		// replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call
   166  		for i := uint8(0); i < c.readPos; i++ {
   167  			buffer := getPacketBuffer()
   168  			buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
   169  			c.buffers[i] = buffer
   170  			c.messages[i].Buffers[0] = c.buffers[i].Data
   171  		}
   172  		c.readPos = 0
   173  
   174  		n, err := c.batchConn.ReadBatch(c.messages, 0)
   175  		if n == 0 || err != nil {
   176  			return receivedPacket{}, err
   177  		}
   178  		c.messages = c.messages[:n]
   179  	}
   180  
   181  	msg := c.messages[c.readPos]
   182  	buffer := c.buffers[c.readPos]
   183  	c.readPos++
   184  
   185  	data := msg.OOB[:msg.NN]
   186  	p := receivedPacket{
   187  		remoteAddr: msg.Addr,
   188  		rcvTime:    time.Now(),
   189  		data:       msg.Buffers[0][:msg.N],
   190  		buffer:     buffer,
   191  	}
   192  	for len(data) > 0 {
   193  		hdr, body, remainder, err := unix.ParseOneSocketControlMessage(data)
   194  		if err != nil {
   195  			return receivedPacket{}, err
   196  		}
   197  		if hdr.Level == unix.IPPROTO_IP {
   198  			switch hdr.Type {
   199  			case msgTypeIPTOS:
   200  				p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
   201  			case ipv4PKTINFO:
   202  				ip, ifIndex, ok := parseIPv4PktInfo(body)
   203  				if ok {
   204  					p.info.addr = ip
   205  					p.info.ifIndex = ifIndex
   206  				} else {
   207  					invalidCmsgOnceV4.Do(func() {
   208  						log.Printf("Received invalid IPv4 packet info control message: %+x. "+
   209  							"This should never occur, please open a new issue and include details about the architecture.", body)
   210  					})
   211  				}
   212  			}
   213  		}
   214  		if hdr.Level == unix.IPPROTO_IPV6 {
   215  			switch hdr.Type {
   216  			case unix.IPV6_TCLASS:
   217  				p.ecn = protocol.ParseECNHeaderBits(body[0] & ecnMask)
   218  			case unix.IPV6_PKTINFO:
   219  				// struct in6_pktinfo {
   220  				// 	struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
   221  				// 	unsigned int    ipi6_ifindex; /* send/recv interface index */
   222  				// };
   223  				if len(body) == 20 {
   224  					p.info.addr = netip.AddrFrom16(*(*[16]byte)(body[:16])).Unmap()
   225  					p.info.ifIndex = binary.LittleEndian.Uint32(body[16:])
   226  				} else {
   227  					invalidCmsgOnceV6.Do(func() {
   228  						log.Printf("Received invalid IPv6 packet info control message: %+x. "+
   229  							"This should never occur, please open a new issue and include details about the architecture.", body)
   230  					})
   231  				}
   232  			}
   233  		}
   234  		data = remainder
   235  	}
   236  	return p, nil
   237  }
   238  
   239  // WritePacket writes a new packet.
   240  func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gsoSize uint16, ecn protocol.ECN) (int, error) {
   241  	oob := packetInfoOOB
   242  	if gsoSize > 0 {
   243  		if !c.capabilities().GSO {
   244  			panic("GSO disabled")
   245  		}
   246  		oob = appendUDPSegmentSizeMsg(oob, gsoSize)
   247  	}
   248  	if ecn != protocol.ECNUnsupported {
   249  		if !c.capabilities().ECN {
   250  			panic("tried to send an ECN-marked packet although ECN is disabled")
   251  		}
   252  		if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok {
   253  			if remoteUDPAddr.IP.To4() != nil {
   254  				oob = appendIPv4ECNMsg(oob, ecn)
   255  			} else {
   256  				oob = appendIPv6ECNMsg(oob, ecn)
   257  			}
   258  		}
   259  	}
   260  	n, _, err := c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
   261  	return n, err
   262  }
   263  
// capabilities returns the connection features (DF, GSO, ECN) that were
// determined when the connection was created in newConn.
func (c *oobConn) capabilities() connCapabilities {
	return c.cap
}
   267  
// packetInfo holds the address and interface index taken from a received
// packet's IPv4 or IPv6 packet info control message (filled in by ReadPacket).
// Its OOB method encodes them as the source address and interface index of
// outgoing packets.
type packetInfo struct {
	addr    netip.Addr
	ifIndex uint32
}
   272  
   273  func (info *packetInfo) OOB() []byte {
   274  	if info == nil {
   275  		return nil
   276  	}
   277  	if info.addr.Is4() {
   278  		ip := info.addr.As4()
   279  		// struct in_pktinfo {
   280  		// 	unsigned int   ipi_ifindex;  /* Interface index */
   281  		// 	struct in_addr ipi_spec_dst; /* Local address */
   282  		// 	struct in_addr ipi_addr;     /* Header Destination address */
   283  		// };
   284  		cm := ipv4.ControlMessage{
   285  			Src:     ip[:],
   286  			IfIndex: int(info.ifIndex),
   287  		}
   288  		return cm.Marshal()
   289  	} else if info.addr.Is6() {
   290  		ip := info.addr.As16()
   291  		// struct in6_pktinfo {
   292  		// 	struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
   293  		// 	unsigned int    ipi6_ifindex; /* send/recv interface index */
   294  		// };
   295  		cm := ipv6.ControlMessage{
   296  			Src:     ip[:],
   297  			IfIndex: int(info.ifIndex),
   298  		}
   299  		return cm.Marshal()
   300  	}
   301  	return nil
   302  }
   303  
   304  func appendIPv4ECNMsg(b []byte, val protocol.ECN) []byte {
   305  	startLen := len(b)
   306  	b = append(b, make([]byte, unix.CmsgSpace(ecnIPv4DataLen))...)
   307  	h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
   308  	h.Level = syscall.IPPROTO_IP
   309  	h.Type = unix.IP_TOS
   310  	h.SetLen(unix.CmsgLen(ecnIPv4DataLen))
   311  
   312  	// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
   313  	offset := startLen + unix.CmsgSpace(0)
   314  	b[offset] = val.ToHeaderBits()
   315  	return b
   316  }
   317  
   318  func appendIPv6ECNMsg(b []byte, val protocol.ECN) []byte {
   319  	startLen := len(b)
   320  	const dataLen = 4
   321  	b = append(b, make([]byte, unix.CmsgSpace(dataLen))...)
   322  	h := (*unix.Cmsghdr)(unsafe.Pointer(&b[startLen]))
   323  	h.Level = syscall.IPPROTO_IPV6
   324  	h.Type = unix.IPV6_TCLASS
   325  	h.SetLen(unix.CmsgLen(dataLen))
   326  
   327  	// UnixRights uses the private `data` method, but I *think* this achieves the same goal.
   328  	offset := startLen + unix.CmsgSpace(0)
   329  	b[offset] = val.ToHeaderBits()
   330  	return b
   331  }