github.com/cilium/cilium@v1.16.2/pkg/fqdn/dnsproxy/udp.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package dnsproxy
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"encoding/binary"
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"strconv"
    14  	"sync"
    15  	"syscall"
    16  	"unsafe"
    17  
    18  	"github.com/cilium/dns"
    19  	"golang.org/x/net/ipv4"
    20  	"golang.org/x/net/ipv6"
    21  	"golang.org/x/sys/unix"
    22  
    23  	"github.com/cilium/cilium/pkg/datapath/linux/linux_defaults"
    24  	"github.com/cilium/cilium/pkg/fqdn/proxy/ipfamily"
    25  	"github.com/cilium/cilium/pkg/option"
    26  )
    27  
// pseudoHeaderLength is the size in bytes of the IPv6 pseudo-header buffer
// built by genIPv6PseudoHeader: 16 bytes of source address, 16 bytes of
// destination address, and 8 trailing bytes that hold the upper-layer length
// and the next-header value.
const pseudoHeaderLength = 40
    29  
    30  // This is the required size of the OOB buffer to pass to ReadMsgUDP.
    31  var udpOOBSize = func() int {
    32  	var hdr unix.Cmsghdr
    33  	var addr unix.RawSockaddrInet6
    34  	return int(unsafe.Sizeof(hdr) + unsafe.Sizeof(addr))
    35  }()
    36  
    37  // Set up new SessionUDPFactory with dedicated raw socket for sending responses.
    38  //   - Must use a raw UDP socket for sending responses so that we can send
    39  //     from a specific port without binding to it.
    40  //   - The raw UDP socket must be bound to a specific IP address to prevent
    41  //     it receiving ALL UDP packets on the host.
    42  //   - We use oob data to override the source IP address when sending
    43  //   - Must use separate sockets for IPv4/IPv6, as sending to a v6-mapped
    44  //     v4 address from a socket bound to "::1" does not work due to kernel
    45  //     checking that a route exists from the source address before
    46  //     the source address is replaced with the (transparently) changed one
    47  func NewSessionUDPFactory(ipFamily ipfamily.IPFamily) (dns.SessionUDPFactory, error) {
    48  	rawResponseConn, err := bindResponseUDPConnection(ipFamily)
    49  	if err != nil {
    50  		return nil, fmt.Errorf("failed to open raw UDP %s socket for DNS Proxy: %w", ipFamily.Name, err)
    51  	}
    52  
    53  	return &sessionUDPFactory{rawResponseConn: rawResponseConn}, nil
    54  }
    55  
// sessionUDPFactory holds the state shared by the sessionUDP values it hands
// out: a pool of reusable sessions and the raw socket used for responses.
type sessionUDPFactory struct {
	// udpPool recycles *sessionUDP values together with their message and
	// out-of-band buffers. Its New function is installed by InitPool.
	udpPool sync.Pool

	// rawResponseConn is the raw IP connection used to send responses.
	// See sessionUDP.WriteResponse.
	rawResponseConn *net.IPConn
}
    64  
    65  // sessionUDP implements the dns.SessionUDP, holding the remote address and the associated
    66  // out-of-band data.
    67  type sessionUDP struct {
    68  	f     *sessionUDPFactory // owner
    69  	conn  *net.UDPConn       // UDP socket for receiving both IPv4 and IPv6
    70  	raddr *net.UDPAddr
    71  	laddr *net.UDPAddr
    72  	m     []byte
    73  	oob   []byte
    74  }
    75  
    76  // Set the socket options needed for tranparent proxying for the listening socket
    77  // IP(V6)_TRANSPARENT allows socket to receive packets with any destination address/port
    78  // IP(V6)_RECVORIGDSTADDR tells the kernel to pass the original destination address/port on recvmsg
    79  // By design, a socket of a DNS Server can only receive IPv4 or IPv6 traffic.
    80  func transparentSetsockopt(fd int, ipFamily ipfamily.IPFamily) error {
    81  	if err := unix.SetsockoptInt(fd, ipFamily.SocketOptsFamily, ipFamily.SocketOptsTransparent, 1); err != nil {
    82  		return fmt.Errorf("setsockopt(IP_TRANSPARENT) for %s failed: %w", ipFamily.Name, err)
    83  	}
    84  	if err := unix.SetsockoptInt(fd, ipFamily.SocketOptsFamily, ipFamily.SocketOptsRecvOrigDstAddr, 1); err != nil {
    85  		return fmt.Errorf("setsockopt(IP_RECVORIGDSTADDR) for %s failed: %w", ipFamily.Name, err)
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  // listenConfig sets the socket options for the fqdn proxy transparent socket.
    92  // Note that it is also used for TCP sockets.
    93  func listenConfig(mark int, ipFamily ipfamily.IPFamily) *net.ListenConfig {
    94  	return &net.ListenConfig{
    95  		Control: func(_, _ string, c syscall.RawConn) error {
    96  			var opErr error
    97  			err := c.Control(func(fd uintptr) {
    98  				if err := transparentSetsockopt(int(fd), ipFamily); err != nil {
    99  					opErr = err
   100  					return
   101  				}
   102  				if mark != 0 {
   103  					if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, mark); err != nil {
   104  						opErr = fmt.Errorf("setsockopt(SO_MARK) failed: %w", err)
   105  						return
   106  					}
   107  				}
   108  				if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1); err != nil {
   109  					opErr = fmt.Errorf("setsockopt(SO_REUSEADDR) failed: %w", err)
   110  					return
   111  				}
   112  				if !option.Config.EnableBPFTProxy {
   113  					if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1); err != nil {
   114  						opErr = fmt.Errorf("setsockopt(SO_REUSEPORT) failed: %w", err)
   115  						return
   116  					}
   117  				}
   118  			})
   119  			if err != nil {
   120  				return err
   121  			}
   122  
   123  			return opErr
   124  		},
   125  	}
   126  }
   127  
   128  func bindResponseUDPConnection(ipFamily ipfamily.IPFamily) (*net.IPConn, error) {
   129  	// Mark outgoing packets as proxy egress return traffic (0x0b00)
   130  	conn, err := listenConfig(linux_defaults.MagicMarkEgress, ipFamily).ListenPacket(context.Background(), "ip:udp", ipFamily.Localhost)
   131  	if err != nil {
   132  		return nil, fmt.Errorf("failed to bind UDP for address %s: %w", ipFamily.Localhost, err)
   133  	}
   134  	return conn.(*net.IPConn), nil
   135  }
   136  
// SetSocketOptions sets up 'conn' to be used with a SessionUDP.
// This implementation is a no-op: the given connection is ignored and nil is
// always returned.
func (f *sessionUDPFactory) SetSocketOptions(_ *net.UDPConn) error {
	// Responses are sent through the factory's raw response connection,
	// which is already properly set up in NewSessionUDPFactory.
	return nil
}
   143  
   144  // InitPool initializes a pool of buffers to be used with SessionUDP.
   145  func (f *sessionUDPFactory) InitPool(msgSize int) {
   146  	f.udpPool.New = func() interface{} {
   147  		return &sessionUDP{
   148  			f:   f,
   149  			m:   make([]byte, msgSize),
   150  			oob: make([]byte, udpOOBSize),
   151  		}
   152  	}
   153  }
   154  
   155  // ReadRequest reads a single request from 'conn' and returns the request context
   156  func (f *sessionUDPFactory) ReadRequest(conn *net.UDPConn) ([]byte, dns.SessionUDP, error) {
   157  	s := f.udpPool.Get().(*sessionUDP)
   158  	n, oobn, _, raddr, err := conn.ReadMsgUDP(s.m, s.oob)
   159  	if err != nil {
   160  		s.Discard()
   161  		return nil, nil, err
   162  	}
   163  	s.conn = conn
   164  	s.raddr = raddr
   165  	s.m = s.m[:n]        // Re-slice to the actual size
   166  	s.oob = s.oob[:oobn] // Re-slice to the actual size
   167  	s.laddr, err = parseDstFromOOB(s.oob)
   168  	if err != nil {
   169  		s.Discard()
   170  		return nil, nil, err
   171  	}
   172  	return s.m, s, err
   173  }
   174  
   175  func (f *sessionUDPFactory) ReadRequestConn(conn net.PacketConn) ([]byte, net.Addr, error) {
   176  	return []byte{}, nil, errors.New("ReadRequestConn is not supported")
   177  }
   178  
   179  // Discard returns 's' to the factory pool
   180  func (s *sessionUDP) Discard() {
   181  	s.conn = nil
   182  	s.raddr = nil
   183  	s.laddr = nil
   184  	s.m = s.m[:cap(s.m)]
   185  	s.oob = s.oob[:cap(s.oob)]
   186  
   187  	s.f.udpPool.Put(s)
   188  }
   189  
   190  // RemoteAddr returns the remote network address.
   191  func (s *sessionUDP) RemoteAddr() net.Addr { return s.raddr }
   192  
   193  // LocalAddr returns the local network address for the current request.
   194  func (s *sessionUDP) LocalAddr() net.Addr { return s.laddr }
   195  
   196  // WriteResponse writes a response to a request received earlier.
   197  // It uses the raw udp connections (IPv4 or IPv6) from its sessionUDPFactory.
   198  func (s *sessionUDP) WriteResponse(b []byte) (int, error) {
   199  	// Must give the UDP header to get the source port right.
   200  	// Reuse the msg buffer, figure out if golang can do gatter-scather IO
   201  	// with raw sockets?
   202  	l := len(b)
   203  	bb := bytes.NewBuffer(s.m[:0])
   204  	binary.Write(bb, binary.BigEndian, uint16(s.laddr.Port))
   205  	binary.Write(bb, binary.BigEndian, uint16(s.raddr.Port))
   206  	binary.Write(bb, binary.BigEndian, uint16(8+l))
   207  	binary.Write(bb, binary.BigEndian, uint16(0)) // checksum
   208  	bb.Write(b)
   209  	buf := bb.Bytes()
   210  
   211  	// A UDP checksum is required for IPv6
   212  	if s.raddr.IP.To4() == nil {
   213  		// Compute the UDP the checksum
   214  		binary.BigEndian.PutUint16(buf[6:8], computeIPv6Checksum(s.laddr.IP, s.raddr.IP, buf))
   215  	}
   216  
   217  	var n int
   218  	var err error
   219  	dst := net.IPAddr{
   220  		IP: s.raddr.IP,
   221  	}
   222  
   223  	n, _, err = s.f.rawResponseConn.WriteMsgIP(buf, s.controlMessage(s.laddr), &dst)
   224  	if err != nil {
   225  		log.WithError(err).Warning("WriteMsgIP failed")
   226  	} else {
   227  		log.Debugf("dnsproxy: Wrote DNS response (%d/%d bytes) from %s to %s", n-8, l, s.laddr.String(), s.raddr.String())
   228  	}
   229  	return n, err
   230  }
   231  
   232  // parseDstFromOOB takes oob data and returns the destination IP.
   233  func parseDstFromOOB(oob []byte) (*net.UDPAddr, error) {
   234  	msgs, err := unix.ParseSocketControlMessage(oob)
   235  	if err != nil {
   236  		return nil, fmt.Errorf("parsing socket control message: %w", err)
   237  	}
   238  
   239  	for _, msg := range msgs {
   240  		if msg.Header.Level == unix.SOL_IP && msg.Header.Type == unix.IP_ORIGDSTADDR {
   241  			pp := &unix.RawSockaddrInet4{}
   242  			// Address family is in native byte order
   243  			family := *(*uint16)(unsafe.Pointer(&msg.Data[unsafe.Offsetof(pp.Family)]))
   244  			if family != unix.AF_INET {
   245  				return nil, fmt.Errorf("original destination is not IPv4")
   246  			}
   247  			// Port is in big-endian byte order
   248  			if err = binary.Read(bytes.NewReader(msg.Data), binary.BigEndian, pp); err != nil {
   249  				return nil, fmt.Errorf("reading original destination address: %w", err)
   250  			}
   251  			laddr := &net.UDPAddr{
   252  				IP:   net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
   253  				Port: int(pp.Port),
   254  			}
   255  			return laddr, nil
   256  		}
   257  		if msg.Header.Level == unix.SOL_IPV6 && msg.Header.Type == unix.IPV6_ORIGDSTADDR {
   258  			pp := &unix.RawSockaddrInet6{}
   259  			// Address family is in native byte order
   260  			family := *(*uint16)(unsafe.Pointer(&msg.Data[unsafe.Offsetof(pp.Family)]))
   261  			if family != unix.AF_INET6 {
   262  				return nil, fmt.Errorf("original destination is not IPv6")
   263  			}
   264  			// Scope ID is in native byte order
   265  			scopeId := *(*uint32)(unsafe.Pointer(&msg.Data[unsafe.Offsetof(pp.Scope_id)]))
   266  			// Rest of the data is big-endian (port)
   267  			if err = binary.Read(bytes.NewReader(msg.Data), binary.BigEndian, pp); err != nil {
   268  				return nil, fmt.Errorf("reading original destination address: %w", err)
   269  			}
   270  			laddr := &net.UDPAddr{
   271  				IP:   net.IP(pp.Addr[:]),
   272  				Port: int(pp.Port),
   273  				Zone: strconv.Itoa(int(scopeId)),
   274  			}
   275  			return laddr, nil
   276  		}
   277  	}
   278  	return nil, fmt.Errorf("no original destination found")
   279  }
   280  
   281  // controlMessage returns the oob data with the given source address
   282  func (s *sessionUDP) controlMessage(src *net.UDPAddr) []byte {
   283  	// If the src is definitely an IPv6, then use ipv6's ControlMessage to
   284  	// respond otherwise use ipv4's because ipv6's marshal ignores ipv4
   285  	// addresses.
   286  	if src.IP.To4() == nil {
   287  		cm := new(ipv6.ControlMessage)
   288  		cm.Src = src.IP
   289  		return cm.Marshal()
   290  	}
   291  	cm := new(ipv4.ControlMessage)
   292  	cm.Src = src.IP
   293  	return cm.Marshal()
   294  }
   295  
   296  // computeIPv6Checksum computes and returns a checksum from the given src/dest IPs
   297  // and UDP header with a payload.
   298  func computeIPv6Checksum(srcIP, dstIP net.IP, udpHeaderWithPayload []byte) uint16 {
   299  	pseudoHeader := genIPv6PseudoHeader(srcIP, dstIP, len(udpHeaderWithPayload))
   300  	packet := append(pseudoHeader, udpHeaderWithPayload...)
   301  	checksum := computeChecksum(packet)
   302  	return checksum
   303  }
   304  
   305  // genIPv6PseudoHeader generates and returns an IPv6 pseudo-header used for calculating
   306  // the checksum of a UDP packet.
   307  func genIPv6PseudoHeader(srcIP, dstIP net.IP, headerAndPayloadSize int) []byte {
   308  	header := make([]byte, pseudoHeaderLength)
   309  	// Source address
   310  	copy(header[0:], srcIP)
   311  	// Destination address
   312  	copy(header[16:], dstIP)
   313  	// Payload length (16-bit field)
   314  	binary.BigEndian.PutUint16(header[32:34], uint16(headerAndPayloadSize))
   315  	if headerAndPayloadSize != 0 {
   316  		// Next header (UDP)
   317  		header[39] = 0x11
   318  	}
   319  	return header
   320  }
   321  
   322  // computeChecksum computes and returns a checksum for the given packet represented as
   323  // a byte slice.
   324  func computeChecksum(packet []byte) uint16 {
   325  	sum := uint32(0)
   326  
   327  	for ; len(packet) >= 2; packet = packet[2:] {
   328  		sum += uint32(packet[0])<<8 | uint32(packet[1])
   329  	}
   330  	if len(packet) > 0 {
   331  		sum += uint32(packet[0]) << 8
   332  	}
   333  	for sum > 0xffff {
   334  		sum = (sum >> 16) + (sum & 0xffff)
   335  	}
   336  	csum := ^uint16(sum)
   337  	/*
   338  	 * From RFC 768:
   339  	 * If the computed checksum is zero, it is transmitted as all ones (the
   340  	 * equivalent in one's complement arithmetic). An all zero transmitted
   341  	 * checksum value means that the transmitter generated no checksum (for
   342  	 * debugging or for higher level protocols that don't care).
   343  	 */
   344  	if csum == 0 {
   345  		csum = 0xffff
   346  	}
   347  	return csum
   348  }