golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/quic/udp_msg.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21 && !quicbasicnet && (darwin || linux)
     6  
     7  package quic
     8  
     9  import (
    10  	"encoding/binary"
    11  	"net"
    12  	"net/netip"
    13  	"sync"
    14  	"unsafe"
    15  
    16  	"golang.org/x/sys/unix"
    17  )
    18  
    19  // Network interface for platforms using sendmsg/recvmsg with cmsgs.
    20  
    21  type netUDPConn struct {
    22  	c         *net.UDPConn
    23  	localAddr netip.AddrPort
    24  }
    25  
    26  func newNetUDPConn(uc *net.UDPConn) (*netUDPConn, error) {
    27  	a, _ := uc.LocalAddr().(*net.UDPAddr)
    28  	localAddr := a.AddrPort()
    29  	if localAddr.Addr().IsUnspecified() {
    30  		// If the conn is not bound to a specified (non-wildcard) address,
    31  		// then set localAddr.Addr to an invalid netip.Addr.
    32  		// This better conveys that this is not an address we should be using,
    33  		// and is a bit more efficient to test against.
    34  		localAddr = netip.AddrPortFrom(netip.Addr{}, localAddr.Port())
    35  	}
    36  
    37  	sc, err := uc.SyscallConn()
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	sc.Control(func(fd uintptr) {
    42  		// Ask for ECN info and (when we aren't bound to a fixed local address)
    43  		// destination info.
    44  		//
    45  		// If any of these calls fail, we won't get the requested information.
    46  		// That's fine, we'll gracefully handle the lack.
    47  		unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
    48  		unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
    49  		if !localAddr.IsValid() {
    50  			unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
    51  			unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
    52  		}
    53  	})
    54  
    55  	return &netUDPConn{
    56  		c:         uc,
    57  		localAddr: localAddr,
    58  	}, nil
    59  }
    60  
    61  func (c *netUDPConn) Close() error { return c.c.Close() }
    62  
    63  func (c *netUDPConn) LocalAddr() netip.AddrPort {
    64  	a, _ := c.c.LocalAddr().(*net.UDPAddr)
    65  	return a.AddrPort()
    66  }
    67  
    68  func (c *netUDPConn) Read(f func(*datagram)) {
    69  	// We shouldn't ever see all of these messages at the same time,
    70  	// but the total is small so just allocate enough space for everything we use.
    71  	const (
    72  		inPktinfoSize  = 12 // int + in_addr + in_addr
    73  		in6PktinfoSize = 20 // in6_addr + int
    74  		ipTOSSize      = 4
    75  		ipv6TclassSize = 4
    76  	)
    77  	control := make([]byte, 0+
    78  		unix.CmsgSpace(inPktinfoSize)+
    79  		unix.CmsgSpace(in6PktinfoSize)+
    80  		unix.CmsgSpace(ipTOSSize)+
    81  		unix.CmsgSpace(ipv6TclassSize))
    82  
    83  	for {
    84  		d := newDatagram()
    85  		n, controlLen, _, peerAddr, err := c.c.ReadMsgUDPAddrPort(d.b, control)
    86  		if err != nil {
    87  			return
    88  		}
    89  		if n == 0 {
    90  			continue
    91  		}
    92  		d.localAddr = c.localAddr
    93  		d.peerAddr = unmapAddrPort(peerAddr)
    94  		d.b = d.b[:n]
    95  		parseControl(d, control[:controlLen])
    96  		f(d)
    97  	}
    98  }
    99  
   100  var cmsgPool = sync.Pool{
   101  	New: func() any {
   102  		return new([]byte)
   103  	},
   104  }
   105  
   106  func (c *netUDPConn) Write(dgram datagram) error {
   107  	controlp := cmsgPool.Get().(*[]byte)
   108  	control := *controlp
   109  	defer func() {
   110  		*controlp = control[:0]
   111  		cmsgPool.Put(controlp)
   112  	}()
   113  
   114  	localIP := dgram.localAddr.Addr()
   115  	if localIP.IsValid() {
   116  		if localIP.Is4() {
   117  			control = appendCmsgIPSourceAddrV4(control, localIP)
   118  		} else {
   119  			control = appendCmsgIPSourceAddrV6(control, localIP)
   120  		}
   121  	}
   122  	if dgram.ecn != ecnNotECT {
   123  		if dgram.peerAddr.Addr().Is4() {
   124  			control = appendCmsgECNv4(control, dgram.ecn)
   125  		} else {
   126  			control = appendCmsgECNv6(control, dgram.ecn)
   127  		}
   128  	}
   129  
   130  	_, _, err := c.c.WriteMsgUDPAddrPort(dgram.b, control, dgram.peerAddr)
   131  	return err
   132  }
   133  
   134  func parseControl(d *datagram, control []byte) {
   135  	for len(control) > 0 {
   136  		hdr, data, remainder, err := unix.ParseOneSocketControlMessage(control)
   137  		if err != nil {
   138  			return
   139  		}
   140  		control = remainder
   141  		switch hdr.Level {
   142  		case unix.IPPROTO_IP:
   143  			switch hdr.Type {
   144  			case unix.IP_TOS, unix.IP_RECVTOS:
   145  				// (Linux sets the type to IP_TOS, Darwin to IP_RECVTOS,
   146  				// just check for both.)
   147  				if ecn, ok := parseIPTOS(data); ok {
   148  					d.ecn = ecn
   149  				}
   150  			case unix.IP_PKTINFO:
   151  				if a, ok := parseInPktinfo(data); ok {
   152  					d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
   153  				}
   154  			}
   155  		case unix.IPPROTO_IPV6:
   156  			switch hdr.Type {
   157  			case unix.IPV6_TCLASS:
   158  				// 32-bit integer containing the traffic class field.
   159  				// The low two bits are the ECN field.
   160  				if ecn, ok := parseIPv6TCLASS(data); ok {
   161  					d.ecn = ecn
   162  				}
   163  			case unix.IPV6_PKTINFO:
   164  				if a, ok := parseIn6Pktinfo(data); ok {
   165  					d.localAddr = netip.AddrPortFrom(a, d.localAddr.Port())
   166  				}
   167  			}
   168  		}
   169  	}
   170  }
   171  
   172  // IPV6_TCLASS is specified by RFC 3542 as an int.
   173  
   174  func parseIPv6TCLASS(b []byte) (ecnBits, bool) {
   175  	if len(b) != 4 {
   176  		return 0, false
   177  	}
   178  	return ecnBits(binary.NativeEndian.Uint32(b) & ecnMask), true
   179  }
   180  
   181  func appendCmsgECNv6(b []byte, ecn ecnBits) []byte {
   182  	b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 4)
   183  	binary.NativeEndian.PutUint32(data, uint32(ecn))
   184  	return b
   185  }
   186  
   187  // struct in_pktinfo {
   188  //   unsigned int   ipi_ifindex;  /* send/recv interface index */
   189  //   struct in_addr ipi_spec_dst; /* Local address */
   190  //   struct in_addr ipi_addr;     /* IP Header dst address */
   191  // };
   192  
   193  // parseInPktinfo returns the destination address from an IP_PKTINFO.
   194  func parseInPktinfo(b []byte) (dst netip.Addr, ok bool) {
   195  	if len(b) != 12 {
   196  		return netip.Addr{}, false
   197  	}
   198  	return netip.AddrFrom4([4]byte(b[8:][:4])), true
   199  }
   200  
   201  // appendCmsgIPSourceAddrV4 appends an IP_PKTINFO setting the source address
   202  // for an outbound datagram.
   203  func appendCmsgIPSourceAddrV4(b []byte, src netip.Addr) []byte {
   204  	// struct in_pktinfo {
   205  	//   unsigned int   ipi_ifindex;  /* send/recv interface index */
   206  	//   struct in_addr ipi_spec_dst; /* Local address */
   207  	//   struct in_addr ipi_addr;     /* IP Header dst address */
   208  	// };
   209  	b, data := appendCmsg(b, unix.IPPROTO_IP, unix.IP_PKTINFO, 12)
   210  	ip := src.As4()
   211  	copy(data[4:], ip[:])
   212  	return b
   213  }
   214  
   215  // struct in6_pktinfo {
   216  //   struct in6_addr  ipi6_addr;    /* src/dst IPv6 address */
   217  //   unsigned int     ipi6_ifindex; /* send/recv interface index */
   218  // };
   219  
   220  // parseIn6Pktinfo returns the destination address from an IPV6_PKTINFO.
   221  func parseIn6Pktinfo(b []byte) (netip.Addr, bool) {
   222  	if len(b) != 20 {
   223  		return netip.Addr{}, false
   224  	}
   225  	return netip.AddrFrom16([16]byte(b[:16])).Unmap(), true
   226  }
   227  
   228  // appendCmsgIPSourceAddrV6 appends an IPV6_PKTINFO setting the source address
   229  // for an outbound datagram.
   230  func appendCmsgIPSourceAddrV6(b []byte, src netip.Addr) []byte {
   231  	b, data := appendCmsg(b, unix.IPPROTO_IPV6, unix.IPV6_PKTINFO, 20)
   232  	ip := src.As16()
   233  	copy(data[0:], ip[:])
   234  	return b
   235  }
   236  
   237  // appendCmsg appends a cmsg with the given level, type, and size to b.
   238  // It returns the new buffer, and the data section of the cmsg.
   239  func appendCmsg(b []byte, level, typ int32, size int) (_, data []byte) {
   240  	off := len(b)
   241  	b = append(b, make([]byte, unix.CmsgSpace(size))...)
   242  	h := (*unix.Cmsghdr)(unsafe.Pointer(&b[off]))
   243  	h.Level = level
   244  	h.Type = typ
   245  	h.SetLen(unix.CmsgLen(size))
   246  	return b, b[off+unix.CmsgSpace(0):][:size]
   247  }