github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/conn/bind_std.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package conn
    33  
    34  import (
    35  	"context"
    36  	"errors"
    37  	"fmt"
    38  	"net"
    39  	"net/netip"
    40  	"runtime"
    41  	"strconv"
    42  	"sync"
    43  	"syscall"
    44  
    45  	"golang.org/x/net/ipv4"
    46  	"golang.org/x/net/ipv6"
    47  )
    48  
    49  var (
    50  	_ Bind = (*StdNetBind)(nil)
    51  )
    52  
    53  // StdNetBind implements Bind for all platforms. While Windows has its own Bind
    54  // (see bind_windows.go), it may fall back to StdNetBind.
    55  // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
    56  // methods for sending and receiving multiple datagrams per-syscall. See the
    57  // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
    58  type StdNetBind struct {
    59  	mu            sync.Mutex // protects all fields except as specified
    60  	ipv4          *net.UDPConn
    61  	ipv6          *net.UDPConn
    62  	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
    63  	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
    64  	ipv4TxOffload bool
    65  	ipv4RxOffload bool
    66  	ipv6TxOffload bool
    67  	ipv6RxOffload bool
    68  
    69  	// these two fields are not guarded by mu
    70  	udpAddrPool sync.Pool
    71  	msgsPool    sync.Pool
    72  
    73  	blackhole4 bool
    74  	blackhole6 bool
    75  }
    76  
    77  func NewStdNetBind() Bind {
    78  	return &StdNetBind{
    79  		udpAddrPool: sync.Pool{
    80  			New: func() any {
    81  				return &net.UDPAddr{
    82  					IP: make([]byte, 16),
    83  				}
    84  			},
    85  		},
    86  
    87  		msgsPool: sync.Pool{
    88  			New: func() any {
    89  				// ipv6.Message and ipv4.Message are interchangeable as they are
    90  				// both aliases for x/net/internal/socket.Message.
    91  				msgs := make([]ipv6.Message, IdealBatchSize)
    92  				for i := range msgs {
    93  					msgs[i].Buffers = make(net.Buffers, 1)
    94  					msgs[i].OOB = make([]byte, 0, gsoControlSize)
    95  				}
    96  				return &msgs
    97  			},
    98  		},
    99  	}
   100  }
   101  
   102  type StdNetEndpoint struct {
   103  	// AddrPort is the endpoint destination.
   104  	netip.AddrPort
   105  }
   106  
   107  var (
   108  	_ Bind     = (*StdNetBind)(nil)
   109  	_ Endpoint = &StdNetEndpoint{}
   110  )
   111  
   112  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
   113  	e, err := netip.ParseAddrPort(s)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	return &StdNetEndpoint{
   118  		AddrPort: e,
   119  	}, nil
   120  }
   121  
   122  func (e *StdNetEndpoint) DstIP() netip.Addr {
   123  	return e.AddrPort.Addr()
   124  }
   125  
   126  // See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
   127  
   128  func (e *StdNetEndpoint) DstToBytes() []byte {
   129  	b, _ := e.AddrPort.MarshalBinary()
   130  	return b
   131  }
   132  
   133  func (e *StdNetEndpoint) DstToString() string {
   134  	return e.AddrPort.String()
   135  }
   136  
   137  func listenNet(network string, port int) (*net.UDPConn, int, error) {
   138  	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
   139  	if err != nil {
   140  		return nil, 0, err
   141  	}
   142  
   143  	// Retrieve port.
   144  	laddr := conn.LocalAddr()
   145  	uaddr, err := net.ResolveUDPAddr(
   146  		laddr.Network(),
   147  		laddr.String(),
   148  	)
   149  	if err != nil {
   150  		return nil, 0, err
   151  	}
   152  	return conn.(*net.UDPConn), uaddr.Port, nil
   153  }
   154  
   155  func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
   156  	s.mu.Lock()
   157  	defer s.mu.Unlock()
   158  
   159  	var err error
   160  	var tries int
   161  
   162  	if s.ipv4 != nil || s.ipv6 != nil {
   163  		return nil, 0, ErrBindAlreadyOpen
   164  	}
   165  
   166  	// Attempt to open ipv4 and ipv6 listeners on the same port.
   167  	// If uport is 0, we can retry on failure.
   168  again:
   169  	port := int(uport)
   170  	var v4conn, v6conn *net.UDPConn
   171  	var v4pc *ipv4.PacketConn
   172  	var v6pc *ipv6.PacketConn
   173  
   174  	v4conn, port, err = listenNet("udp4", port)
   175  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   176  		return nil, 0, err
   177  	}
   178  
   179  	// Listen on the same port as we're using for ipv4.
   180  	v6conn, port, err = listenNet("udp6", port)
   181  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   182  		v4conn.Close()
   183  		tries++
   184  		goto again
   185  	}
   186  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   187  		v4conn.Close()
   188  		return nil, 0, err
   189  	}
   190  	var fns []ReceiveFunc
   191  	if v4conn != nil {
   192  		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
   193  		if runtime.GOOS == "linux" {
   194  			v4pc = ipv4.NewPacketConn(v4conn)
   195  			s.ipv4PC = v4pc
   196  		}
   197  		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
   198  		s.ipv4 = v4conn
   199  	}
   200  	if v6conn != nil {
   201  		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
   202  		if runtime.GOOS == "linux" {
   203  			v6pc = ipv6.NewPacketConn(v6conn)
   204  			s.ipv6PC = v6pc
   205  		}
   206  		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
   207  		s.ipv6 = v6conn
   208  	}
   209  	if len(fns) == 0 {
   210  		return nil, 0, syscall.EAFNOSUPPORT
   211  	}
   212  
   213  	return fns, uint16(port), nil
   214  }
   215  
   216  func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
   217  	for i := range *msgs {
   218  		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
   219  		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
   220  	}
   221  	s.msgsPool.Put(msgs)
   222  }
   223  
   224  func (s *StdNetBind) getMessages() *[]ipv6.Message {
   225  	return s.msgsPool.Get().(*[]ipv6.Message)
   226  }
   227  
   228  var (
   229  	// If compilation fails here these are no longer the same underlying type.
   230  	_ ipv6.Message = ipv4.Message{}
   231  )
   232  
   233  type batchReader interface {
   234  	ReadBatch([]ipv6.Message, int) (int, error)
   235  }
   236  
   237  type batchWriter interface {
   238  	WriteBatch([]ipv6.Message, int) (int, error)
   239  }
   240  
   241  func (s *StdNetBind) receiveIP(
   242  	br batchReader,
   243  	conn *net.UDPConn,
   244  	rxOffload bool,
   245  	bufs [][]byte,
   246  	sizes []int,
   247  	eps []Endpoint,
   248  ) (n int, err error) {
   249  	msgs := s.getMessages()
   250  	for i := range bufs {
   251  		(*msgs)[i].Buffers[0] = bufs[i]
   252  		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
   253  	}
   254  	defer s.putMessages(msgs)
   255  	var numMsgs int
   256  	if runtime.GOOS == "linux" {
   257  		if rxOffload {
   258  			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
   259  			_, err = br.ReadBatch((*msgs)[readAt:], 0)
   260  			if err != nil {
   261  				return 0, err
   262  			}
   263  			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
   264  			if err != nil {
   265  				return 0, err
   266  			}
   267  		} else {
   268  			numMsgs, err = br.ReadBatch(*msgs, 0)
   269  			if err != nil {
   270  				return 0, err
   271  			}
   272  		}
   273  	} else {
   274  		msg := &(*msgs)[0]
   275  		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
   276  		if err != nil {
   277  			return 0, err
   278  		}
   279  		numMsgs = 1
   280  	}
   281  	for i := 0; i < numMsgs; i++ {
   282  		msg := &(*msgs)[i]
   283  		sizes[i] = msg.N
   284  		if sizes[i] == 0 {
   285  			continue
   286  		}
   287  		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
   288  		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
   289  		eps[i] = ep
   290  	}
   291  	return numMsgs, nil
   292  }
   293  
   294  func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
   295  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   296  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   297  	}
   298  }
   299  
   300  func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
   301  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   302  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   303  	}
   304  }
   305  
   306  func (s *StdNetBind) BatchSize() int {
   307  	if runtime.GOOS == "linux" {
   308  		return IdealBatchSize
   309  	}
   310  	return 1
   311  }
   312  
   313  func (s *StdNetBind) Close() error {
   314  	s.mu.Lock()
   315  	defer s.mu.Unlock()
   316  
   317  	var err1, err2 error
   318  	if s.ipv4 != nil {
   319  		err1 = s.ipv4.Close()
   320  		s.ipv4 = nil
   321  		s.ipv4PC = nil
   322  	}
   323  	if s.ipv6 != nil {
   324  		err2 = s.ipv6.Close()
   325  		s.ipv6 = nil
   326  		s.ipv6PC = nil
   327  	}
   328  	s.blackhole4 = false
   329  	s.blackhole6 = false
   330  	s.ipv4TxOffload = false
   331  	s.ipv4RxOffload = false
   332  	s.ipv6TxOffload = false
   333  	s.ipv6RxOffload = false
   334  	if err1 != nil {
   335  		return err1
   336  	}
   337  	return err2
   338  }
   339  
   340  type ErrUDPGSODisabled struct {
   341  	onLaddr  string
   342  	RetryErr error
   343  }
   344  
   345  func (e ErrUDPGSODisabled) Error() string {
   346  	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
   347  }
   348  
   349  func (e ErrUDPGSODisabled) Unwrap() error {
   350  	return e.RetryErr
   351  }
   352  
   353  func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
   354  	s.mu.Lock()
   355  	blackhole := s.blackhole4
   356  	conn := s.ipv4
   357  	offload := s.ipv4TxOffload
   358  	br := batchWriter(s.ipv4PC)
   359  	is6 := false
   360  	if endpoint.DstIP().Is6() {
   361  		blackhole = s.blackhole6
   362  		conn = s.ipv6
   363  		br = s.ipv6PC
   364  		is6 = true
   365  		offload = s.ipv6TxOffload
   366  	}
   367  	s.mu.Unlock()
   368  
   369  	if blackhole {
   370  		return nil
   371  	}
   372  	if conn == nil {
   373  		return syscall.EAFNOSUPPORT
   374  	}
   375  
   376  	msgs := s.getMessages()
   377  	defer s.putMessages(msgs)
   378  	ua := s.udpAddrPool.Get().(*net.UDPAddr)
   379  	defer s.udpAddrPool.Put(ua)
   380  	if is6 {
   381  		as16 := endpoint.DstIP().As16()
   382  		copy(ua.IP, as16[:])
   383  		ua.IP = ua.IP[:16]
   384  	} else {
   385  		as4 := endpoint.DstIP().As4()
   386  		copy(ua.IP, as4[:])
   387  		ua.IP = ua.IP[:4]
   388  	}
   389  	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
   390  	var (
   391  		retried bool
   392  		err     error
   393  	)
   394  retry:
   395  	if offload {
   396  		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
   397  		err = s.send(conn, br, (*msgs)[:n])
   398  		if err != nil && offload && errShouldDisableUDPGSO(err) {
   399  			offload = false
   400  			s.mu.Lock()
   401  			if is6 {
   402  				s.ipv6TxOffload = false
   403  			} else {
   404  				s.ipv4TxOffload = false
   405  			}
   406  			s.mu.Unlock()
   407  			retried = true
   408  			goto retry
   409  		}
   410  	} else {
   411  		for i := range bufs {
   412  			(*msgs)[i].Addr = ua
   413  			(*msgs)[i].Buffers[0] = bufs[i]
   414  		}
   415  		err = s.send(conn, br, (*msgs)[:len(bufs)])
   416  	}
   417  	if retried {
   418  		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
   419  	}
   420  	return err
   421  }
   422  
   423  func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
   424  	var (
   425  		n     int
   426  		err   error
   427  		start int
   428  	)
   429  	if runtime.GOOS == "linux" {
   430  		for {
   431  			n, err = pc.WriteBatch(msgs[start:], 0)
   432  			if err != nil || n == len(msgs[start:]) {
   433  				break
   434  			}
   435  			start += n
   436  		}
   437  	} else {
   438  		for _, msg := range msgs {
   439  			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
   440  			if err != nil {
   441  				break
   442  			}
   443  		}
   444  	}
   445  	return err
   446  }
   447  
   448  const (
   449  	// Exceeding these values results in EMSGSIZE. They account for layer3 and
   450  	// layer4 headers. IPv6 does not need to account for itself as the payload
   451  	// length field is self excluding.
   452  	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
   453  	maxIPv6PayloadLen = 1<<16 - 1 - 8
   454  
   455  	// This is a hard limit imposed by the kernel.
   456  	udpSegmentMaxDatagrams = 64
   457  )
   458  
   459  type setGSOFunc func(control *[]byte, gsoSize uint16)
   460  
   461  func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
   462  	var (
   463  		base     = -1 // index of msg we are currently coalescing into
   464  		gsoSize  int  // segmentation size of msgs[base]
   465  		dgramCnt int  // number of dgrams coalesced into msgs[base]
   466  		endBatch bool // tracking flag to start a new batch on next iteration of bufs
   467  	)
   468  	maxPayloadLen := maxIPv4PayloadLen
   469  	if ep.DstIP().Is6() {
   470  		maxPayloadLen = maxIPv6PayloadLen
   471  	}
   472  	for i, buf := range bufs {
   473  		if i > 0 {
   474  			msgLen := len(buf)
   475  			baseLenBefore := len(msgs[base].Buffers[0])
   476  			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
   477  			if msgLen+baseLenBefore <= maxPayloadLen &&
   478  				msgLen <= gsoSize &&
   479  				msgLen <= freeBaseCap &&
   480  				dgramCnt < udpSegmentMaxDatagrams &&
   481  				!endBatch {
   482  				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
   483  				if i == len(bufs)-1 {
   484  					setGSO(&msgs[base].OOB, uint16(gsoSize))
   485  				}
   486  				dgramCnt++
   487  				if msgLen < gsoSize {
   488  					// A smaller than gsoSize packet on the tail is legal, but
   489  					// it must end the batch.
   490  					endBatch = true
   491  				}
   492  				continue
   493  			}
   494  		}
   495  		if dgramCnt > 1 {
   496  			setGSO(&msgs[base].OOB, uint16(gsoSize))
   497  		}
   498  		// Reset prior to incrementing base since we are preparing to start a
   499  		// new potential batch.
   500  		endBatch = false
   501  		base++
   502  		gsoSize = len(buf)
   503  		msgs[base].Buffers[0] = buf
   504  		msgs[base].Addr = addr
   505  		dgramCnt = 1
   506  	}
   507  	return base + 1
   508  }
   509  
   510  type getGSOFunc func(control []byte) (int, error)
   511  
   512  func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
   513  	for i := firstMsgAt; i < len(msgs); i++ {
   514  		msg := &msgs[i]
   515  		if msg.N == 0 {
   516  			return n, err
   517  		}
   518  		var (
   519  			gsoSize    int
   520  			start      int
   521  			end        = msg.N
   522  			numToSplit = 1
   523  		)
   524  		gsoSize, err = getGSO(msg.OOB[:msg.NN])
   525  		if err != nil {
   526  			return n, err
   527  		}
   528  		if gsoSize > 0 {
   529  			numToSplit = (msg.N + gsoSize - 1) / gsoSize
   530  			end = gsoSize
   531  		}
   532  		for j := 0; j < numToSplit; j++ {
   533  			if n > i {
   534  				return n, errors.New("splitting coalesced packet resulted in overflow")
   535  			}
   536  			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
   537  			msgs[n].N = copied
   538  			msgs[n].Addr = msg.Addr
   539  			start = end
   540  			end += gsoSize
   541  			if end > msg.N {
   542  				end = msg.N
   543  			}
   544  			n++
   545  		}
   546  		if i != n-1 {
   547  			// It is legal for bytes to move within msg.Buffers[0] as a result
   548  			// of splitting, so we only zero the source msg len when it is not
   549  			// the destination of the last split operation above.
   550  			msg.N = 0
   551  		}
   552  	}
   553  	return n, nil
   554  }