github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/conn_oob.go (about)

     1  //go:build darwin || linux || freebsd
     2  // +build darwin linux freebsd
     3  
     4  package quic
     5  
     6  import (
     7  	"encoding/binary"
     8  	"errors"
     9  	"fmt"
    10  	"net"
    11  	"syscall"
    12  	"time"
    13  
    14  	"golang.org/x/net/ipv4"
    15  	"golang.org/x/net/ipv6"
    16  	"golang.org/x/sys/unix"
    17  
    18  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    19  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/utils"
    20  )
    21  
    22  const (
    23  	ecnMask       = 0x3
    24  	oobBufferSize = 128
    25  )
    26  
    27  // Contrary to what the naming suggests, the ipv{4,6}.Message is not dependent on the IP version.
    28  // They're both just aliases for x/net/internal/socket.Message.
    29  // This means we can use this struct to read from a socket that receives both IPv4 and IPv6 messages.
    30  var _ ipv4.Message = ipv6.Message{}
    31  
    32  type batchConn interface {
    33  	ReadBatch(ms []ipv4.Message, flags int) (int, error)
    34  }
    35  
    36  func inspectReadBuffer(c interface{}) (int, error) {
    37  	conn, ok := c.(interface {
    38  		SyscallConn() (syscall.RawConn, error)
    39  	})
    40  	if !ok {
    41  		return 0, errors.New("doesn't have a SyscallConn")
    42  	}
    43  	rawConn, err := conn.SyscallConn()
    44  	if err != nil {
    45  		return 0, fmt.Errorf("couldn't get syscall.RawConn: %w", err)
    46  	}
    47  	var size int
    48  	var serr error
    49  	if err := rawConn.Control(func(fd uintptr) {
    50  		size, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF)
    51  	}); err != nil {
    52  		return 0, err
    53  	}
    54  	return size, serr
    55  }
    56  
    57  type oobConn struct {
    58  	OOBCapablePacketConn
    59  	batchConn batchConn
    60  
    61  	readPos uint8
    62  	// Packets received from the kernel, but not yet returned by ReadPacket().
    63  	messages []ipv4.Message
    64  	buffers  [batchSize]*packetBuffer
    65  }
    66  
    67  var _ connection = &oobConn{}
    68  
    69  func newConn(c OOBCapablePacketConn) (*oobConn, error) {
    70  	rawConn, err := c.SyscallConn()
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  	needsPacketInfo := false
    75  	if udpAddr, ok := c.LocalAddr().(*net.UDPAddr); ok && udpAddr.IP.IsUnspecified() {
    76  		needsPacketInfo = true
    77  	}
    78  	// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
    79  	// Try enabling receiving of ECN and packet info for both IP versions.
    80  	// We expect at least one of those syscalls to succeed.
    81  	var errECNIPv4, errECNIPv6, errPIIPv4, errPIIPv6 error
    82  	if err := rawConn.Control(func(fd uintptr) {
    83  		errECNIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
    84  		errECNIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
    85  
    86  		if needsPacketInfo {
    87  			errPIIPv4 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, ipv4RECVPKTINFO, 1)
    88  			errPIIPv6 = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, ipv6RECVPKTINFO, 1)
    89  		}
    90  	}); err != nil {
    91  		return nil, err
    92  	}
    93  	switch {
    94  	case errECNIPv4 == nil && errECNIPv6 == nil:
    95  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4 and IPv6.")
    96  	case errECNIPv4 == nil && errECNIPv6 != nil:
    97  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv4.")
    98  	case errECNIPv4 != nil && errECNIPv6 == nil:
    99  		utils.DefaultLogger.Debugf("Activating reading of ECN bits for IPv6.")
   100  	case errECNIPv4 != nil && errECNIPv6 != nil:
   101  		return nil, errors.New("activating ECN failed for both IPv4 and IPv6")
   102  	}
   103  	if needsPacketInfo {
   104  		switch {
   105  		case errPIIPv4 == nil && errPIIPv6 == nil:
   106  			utils.DefaultLogger.Debugf("Activating reading of packet info for IPv4 and IPv6.")
   107  		case errPIIPv4 == nil && errPIIPv6 != nil:
   108  			utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv4.")
   109  		case errPIIPv4 != nil && errPIIPv6 == nil:
   110  			utils.DefaultLogger.Debugf("Activating reading of packet info bits for IPv6.")
   111  		case errPIIPv4 != nil && errPIIPv6 != nil:
   112  			return nil, errors.New("activating packet info failed for both IPv4 and IPv6")
   113  		}
   114  	}
   115  
   116  	// Allows callers to pass in a connection that already satisfies batchConn interface
   117  	// to make use of the optimisation. Otherwise, ipv4.NewPacketConn would unwrap the file descriptor
   118  	// via SyscallConn(), and read it that way, which might not be what the caller wants.
   119  	var bc batchConn
   120  	if ibc, ok := c.(batchConn); ok {
   121  		bc = ibc
   122  	} else {
   123  		bc = ipv4.NewPacketConn(c)
   124  	}
   125  
   126  	oobConn := &oobConn{
   127  		OOBCapablePacketConn: c,
   128  		batchConn:            bc,
   129  		messages:             make([]ipv4.Message, batchSize),
   130  		readPos:              batchSize,
   131  	}
   132  	for i := 0; i < batchSize; i++ {
   133  		oobConn.messages[i].OOB = make([]byte, oobBufferSize)
   134  	}
   135  	return oobConn, nil
   136  }
   137  
   138  func (c *oobConn) ReadPacket() (*receivedPacket, error) {
   139  	if len(c.messages) == int(c.readPos) { // all messages read. Read the next batch of messages.
   140  		c.messages = c.messages[:batchSize]
   141  		// replace buffers data buffers up to the packet that has been consumed during the last ReadBatch call
   142  		for i := uint8(0); i < c.readPos; i++ {
   143  			buffer := getPacketBuffer()
   144  			buffer.Data = buffer.Data[:protocol.MaxPacketBufferSize]
   145  			c.buffers[i] = buffer
   146  			c.messages[i].Buffers = [][]byte{c.buffers[i].Data}
   147  		}
   148  		c.readPos = 0
   149  
   150  		n, err := c.batchConn.ReadBatch(c.messages, 0)
   151  		if n == 0 || err != nil {
   152  			return nil, err
   153  		}
   154  		c.messages = c.messages[:n]
   155  	}
   156  
   157  	msg := c.messages[c.readPos]
   158  	buffer := c.buffers[c.readPos]
   159  	c.readPos++
   160  	ctrlMsgs, err := unix.ParseSocketControlMessage(msg.OOB[:msg.NN])
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	var ecn protocol.ECN
   165  	var destIP net.IP
   166  	var ifIndex uint32
   167  	for _, ctrlMsg := range ctrlMsgs {
   168  		if ctrlMsg.Header.Level == unix.IPPROTO_IP {
   169  			switch ctrlMsg.Header.Type {
   170  			case msgTypeIPTOS:
   171  				ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
   172  			case msgTypeIPv4PKTINFO:
   173  				// struct in_pktinfo {
   174  				// 	unsigned int   ipi_ifindex;  /* Interface index */
   175  				// 	struct in_addr ipi_spec_dst; /* Local address */
   176  				// 	struct in_addr ipi_addr;     /* Header Destination
   177  				// 									address */
   178  				// };
   179  				ip := make([]byte, 4)
   180  				if len(ctrlMsg.Data) == 12 {
   181  					ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data)
   182  					copy(ip, ctrlMsg.Data[8:12])
   183  				} else if len(ctrlMsg.Data) == 4 {
   184  					// FreeBSD
   185  					copy(ip, ctrlMsg.Data)
   186  				}
   187  				destIP = net.IP(ip)
   188  			}
   189  		}
   190  		if ctrlMsg.Header.Level == unix.IPPROTO_IPV6 {
   191  			switch ctrlMsg.Header.Type {
   192  			case unix.IPV6_TCLASS:
   193  				ecn = protocol.ECN(ctrlMsg.Data[0] & ecnMask)
   194  			case msgTypeIPv6PKTINFO:
   195  				// struct in6_pktinfo {
   196  				// 	struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
   197  				// 	unsigned int    ipi6_ifindex; /* send/recv interface index */
   198  				// };
   199  				if len(ctrlMsg.Data) == 20 {
   200  					ip := make([]byte, 16)
   201  					copy(ip, ctrlMsg.Data[:16])
   202  					destIP = net.IP(ip)
   203  					ifIndex = binary.LittleEndian.Uint32(ctrlMsg.Data[16:])
   204  				}
   205  			}
   206  		}
   207  	}
   208  	var info *packetInfo
   209  	if destIP != nil {
   210  		info = &packetInfo{
   211  			addr:    destIP,
   212  			ifIndex: ifIndex,
   213  		}
   214  	}
   215  	return &receivedPacket{
   216  		remoteAddr: msg.Addr,
   217  		rcvTime:    time.Now(),
   218  		data:       msg.Buffers[0][:msg.N],
   219  		ecn:        ecn,
   220  		info:       info,
   221  		buffer:     buffer,
   222  	}, nil
   223  }
   224  
   225  func (c *oobConn) WritePacket(b []byte, addr net.Addr, oob []byte) (n int, err error) {
   226  	n, _, err = c.OOBCapablePacketConn.WriteMsgUDP(b, oob, addr.(*net.UDPAddr))
   227  	return n, err
   228  }
   229  
   230  func (info *packetInfo) OOB() []byte {
   231  	if info == nil {
   232  		return nil
   233  	}
   234  	if ip4 := info.addr.To4(); ip4 != nil {
   235  		// struct in_pktinfo {
   236  		// 	unsigned int   ipi_ifindex;  /* Interface index */
   237  		// 	struct in_addr ipi_spec_dst; /* Local address */
   238  		// 	struct in_addr ipi_addr;     /* Header Destination address */
   239  		// };
   240  		cm := ipv4.ControlMessage{
   241  			Src:     ip4,
   242  			IfIndex: int(info.ifIndex),
   243  		}
   244  		return cm.Marshal()
   245  	} else if len(info.addr) == 16 {
   246  		// struct in6_pktinfo {
   247  		// 	struct in6_addr ipi6_addr;    /* src/dst IPv6 address */
   248  		// 	unsigned int    ipi6_ifindex; /* send/recv interface index */
   249  		// };
   250  		cm := ipv6.ControlMessage{
   251  			Src:     info.addr,
   252  			IfIndex: int(info.ifIndex),
   253  		}
   254  		return cm.Marshal()
   255  	}
   256  	return nil
   257  }