github.com/amnezia-vpn/amnezia-wg@v0.1.8/conn/bind_std.go

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"context"
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"net/netip"
    14  	"runtime"
    15  	"strconv"
    16  	"sync"
    17  	"syscall"
    18  
    19  	"golang.org/x/net/ipv4"
    20  	"golang.org/x/net/ipv6"
    21  )
    22  
    23  var (
    24  	_ Bind = (*StdNetBind)(nil)
    25  )
    26  
    27  // StdNetBind implements Bind for all platforms. While Windows has its own Bind
    28  // (see bind_windows.go), it may fall back to StdNetBind.
    29  // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
    30  // methods for sending and receiving multiple datagrams per-syscall. See the
    31  // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
    32  type StdNetBind struct {
    33  	mu            sync.Mutex // protects all fields except as specified
    34  	ipv4          *net.UDPConn
    35  	ipv6          *net.UDPConn
    36  	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
    37  	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
    38  	ipv4TxOffload bool
    39  	ipv4RxOffload bool
    40  	ipv6TxOffload bool
    41  	ipv6RxOffload bool
    42  
    43  	// these two fields are not guarded by mu
    44  	udpAddrPool sync.Pool
    45  	msgsPool    sync.Pool
    46  
    47  	blackhole4 bool
    48  	blackhole6 bool
    49  }
    50  
    51  func NewStdNetBind() Bind {
    52  	return &StdNetBind{
    53  		udpAddrPool: sync.Pool{
    54  			New: func() any {
    55  				return &net.UDPAddr{
    56  					IP: make([]byte, 16),
    57  				}
    58  			},
    59  		},
    60  
    61  		msgsPool: sync.Pool{
    62  			New: func() any {
    63  				// ipv6.Message and ipv4.Message are interchangeable as they are
    64  				// both aliases for x/net/internal/socket.Message.
    65  				msgs := make([]ipv6.Message, IdealBatchSize)
    66  				for i := range msgs {
    67  					msgs[i].Buffers = make(net.Buffers, 1)
    68  					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
    69  				}
    70  				return &msgs
    71  			},
    72  		},
    73  	}
    74  }
    75  
    76  type StdNetEndpoint struct {
    77  	// AddrPort is the endpoint destination.
    78  	netip.AddrPort
    79  	// src is the current sticky source address and interface index, if
    80  	// supported. Typically this is a PKTINFO structure from/for control
    81  	// messages, see unix.PKTINFO for an example.
    82  	src []byte
    83  }
    84  
    85  var (
    86  	_ Bind     = (*StdNetBind)(nil)
    87  	_ Endpoint = &StdNetEndpoint{}
    88  )
    89  
    90  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
    91  	e, err := netip.ParseAddrPort(s)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	return &StdNetEndpoint{
    96  		AddrPort: e,
    97  	}, nil
    98  }
    99  
   100  func (e *StdNetEndpoint) ClearSrc() {
   101  	if e.src != nil {
   102  		// Truncate src, no need to reallocate.
   103  		e.src = e.src[:0]
   104  	}
   105  }
   106  
   107  func (e *StdNetEndpoint) DstIP() netip.Addr {
   108  	return e.AddrPort.Addr()
   109  }
   110  
    111  // See control_default, control_linux, etc. for implementations of SrcIP and SrcIfidx.
   112  
   113  func (e *StdNetEndpoint) DstToBytes() []byte {
   114  	b, _ := e.AddrPort.MarshalBinary()
   115  	return b
   116  }
   117  
   118  func (e *StdNetEndpoint) DstToString() string {
   119  	return e.AddrPort.String()
   120  }
   121  
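         // Illustrative sketch (not from the upstream file): how a caller might build
         // a StdNetEndpoint from an "ip:port" string and inspect it. ParseEndpoint
         // only accepts literal IP addresses, so hostnames must be resolved first;
         // the address below is an example value.
         func exampleParseEndpoint() error {
         	var b StdNetBind
         	ep, err := b.ParseEndpoint("192.0.2.10:51820")
         	if err != nil {
         		return err
         	}
         	// DstIP yields the netip.Addr 192.0.2.10; DstToString returns the
         	// "ip:port" form, useful for logging and config round-trips.
         	_ = ep.DstIP()
         	_ = ep.DstToString()
         	return nil
         }
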
   122  func listenNet(network string, port int) (*net.UDPConn, int, error) {
   123  	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
   124  	if err != nil {
   125  		return nil, 0, err
   126  	}
   127  
   128  	// Retrieve port.
   129  	laddr := conn.LocalAddr()
   130  	uaddr, err := net.ResolveUDPAddr(
   131  		laddr.Network(),
   132  		laddr.String(),
   133  	)
   134  	if err != nil {
   135  		return nil, 0, err
   136  	}
   137  	return conn.(*net.UDPConn), uaddr.Port, nil
   138  }
   139  
   140  func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
   141  	s.mu.Lock()
   142  	defer s.mu.Unlock()
   143  
   144  	var err error
   145  	var tries int
   146  
   147  	if s.ipv4 != nil || s.ipv6 != nil {
   148  		return nil, 0, ErrBindAlreadyOpen
   149  	}
   150  
   151  	// Attempt to open ipv4 and ipv6 listeners on the same port.
   152  	// If uport is 0, we can retry on failure.
   153  again:
   154  	port := int(uport)
   155  	var v4conn, v6conn *net.UDPConn
   156  	var v4pc *ipv4.PacketConn
   157  	var v6pc *ipv6.PacketConn
   158  
   159  	v4conn, port, err = listenNet("udp4", port)
   160  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   161  		return nil, 0, err
   162  	}
   163  
   164  	// Listen on the same port as we're using for ipv4.
   165  	v6conn, port, err = listenNet("udp6", port)
   166  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   167  		v4conn.Close()
   168  		tries++
   169  		goto again
   170  	}
   171  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   172  		v4conn.Close()
   173  		return nil, 0, err
   174  	}
   175  	var fns []ReceiveFunc
   176  	if v4conn != nil {
   177  		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
   178  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   179  			v4pc = ipv4.NewPacketConn(v4conn)
   180  			s.ipv4PC = v4pc
   181  		}
   182  		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
   183  		s.ipv4 = v4conn
   184  	}
   185  	if v6conn != nil {
   186  		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
   187  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   188  			v6pc = ipv6.NewPacketConn(v6conn)
   189  			s.ipv6PC = v6pc
   190  		}
   191  		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
   192  		s.ipv6 = v6conn
   193  	}
   194  	if len(fns) == 0 {
   195  		return nil, 0, syscall.EAFNOSUPPORT
   196  	}
   197  
   198  	return fns, uint16(port), nil
   199  }
   200  
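         // Illustrative sketch (not from the upstream file): one way a caller could
         // drive Open and the ReceiveFuncs it returns. Each ReceiveFunc blocks on its
         // own socket, so callers typically run one goroutine per function. The
         // 1700-byte buffer size is an assumed value; real callers size buffers to
         // their MTU plus headroom and reuse them between calls.
         func exampleOpenAndReceive() error {
         	bind := NewStdNetBind()
         	fns, port, err := bind.Open(0) // port 0 lets the kernel choose
         	if err != nil {
         		return err
         	}
         	_ = port
         	for _, fn := range fns {
         		go func(recv ReceiveFunc) {
         			batch := bind.BatchSize()
         			bufs := make([][]byte, batch)
         			for i := range bufs {
         				bufs[i] = make([]byte, 1700)
         			}
         			sizes := make([]int, batch)
         			eps := make([]Endpoint, batch)
         			for {
         				n, err := recv(bufs, sizes, eps)
         				if err != nil {
         					return
         				}
         				for i := 0; i < n; i++ {
         					_ = bufs[i][:sizes[i]] // packet i, sent by eps[i]
         				}
         			}
         		}(fn)
         	}
         	return nil
         }
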
   201  func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
   202  	for i := range *msgs {
   203  		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
   204  		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
   205  	}
   206  	s.msgsPool.Put(msgs)
   207  }
   208  
   209  func (s *StdNetBind) getMessages() *[]ipv6.Message {
   210  	return s.msgsPool.Get().(*[]ipv6.Message)
   211  }
   212  
   213  var (
   214  	// If compilation fails here these are no longer the same underlying type.
   215  	_ ipv6.Message = ipv4.Message{}
   216  )
   217  
   218  type batchReader interface {
   219  	ReadBatch([]ipv6.Message, int) (int, error)
   220  }
   221  
   222  type batchWriter interface {
   223  	WriteBatch([]ipv6.Message, int) (int, error)
   224  }
   225  
   226  func (s *StdNetBind) receiveIP(
   227  	br batchReader,
   228  	conn *net.UDPConn,
   229  	rxOffload bool,
   230  	bufs [][]byte,
   231  	sizes []int,
   232  	eps []Endpoint,
   233  ) (n int, err error) {
   234  	msgs := s.getMessages()
   235  	for i := range bufs {
   236  		(*msgs)[i].Buffers[0] = bufs[i]
   237  		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
   238  	}
   239  	defer s.putMessages(msgs)
   240  	var numMsgs int
   241  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   242  		if rxOffload {
   243  			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
   244  			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
   245  			if err != nil {
   246  				return 0, err
   247  			}
   248  			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
   249  			if err != nil {
   250  				return 0, err
   251  			}
   252  		} else {
   253  			numMsgs, err = br.ReadBatch(*msgs, 0)
   254  			if err != nil {
   255  				return 0, err
   256  			}
   257  		}
   258  	} else {
   259  		msg := &(*msgs)[0]
   260  		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
   261  		if err != nil {
   262  			return 0, err
   263  		}
   264  		numMsgs = 1
   265  	}
   266  	for i := 0; i < numMsgs; i++ {
   267  		msg := &(*msgs)[i]
   268  		sizes[i] = msg.N
   269  		if sizes[i] == 0 {
   270  			continue
   271  		}
   272  		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
   273  		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
   274  		getSrcFromControl(msg.OOB[:msg.NN], ep)
   275  		eps[i] = ep
   276  	}
   277  	return numMsgs, nil
   278  }
   279  
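         // Worked through, for the rxOffload branch above: reading into the tail of
         // msgs leaves room for the split. Assuming the package's IdealBatchSize of
         // 128 and udpSegmentMaxDatagrams of 64, readAt = 128 - 128/64 = 126, so at
         // most 2 coalesced messages are read per batch; even if both carry the
         // maximum 64 segments, splitCoalescedMessages fans them out into at most
         // 128 messages, which still fits in the caller's BatchSize-long bufs,
         // sizes and eps slices.
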
   280  func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
   281  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   282  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   283  	}
   284  }
   285  
   286  func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
   287  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   288  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   289  	}
   290  }
   291  
   292  // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
   293  // rename the IdealBatchSize constant to BatchSize.
   294  func (s *StdNetBind) BatchSize() int {
   295  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   296  		return IdealBatchSize
   297  	}
   298  	return 1
   299  }
   300  
   301  func (s *StdNetBind) Close() error {
   302  	s.mu.Lock()
   303  	defer s.mu.Unlock()
   304  
   305  	var err1, err2 error
   306  	if s.ipv4 != nil {
   307  		err1 = s.ipv4.Close()
   308  		s.ipv4 = nil
   309  		s.ipv4PC = nil
   310  	}
   311  	if s.ipv6 != nil {
   312  		err2 = s.ipv6.Close()
   313  		s.ipv6 = nil
   314  		s.ipv6PC = nil
   315  	}
   316  	s.blackhole4 = false
   317  	s.blackhole6 = false
   318  	s.ipv4TxOffload = false
   319  	s.ipv4RxOffload = false
   320  	s.ipv6TxOffload = false
   321  	s.ipv6RxOffload = false
   322  	if err1 != nil {
   323  		return err1
   324  	}
   325  	return err2
   326  }
   327  
   328  type ErrUDPGSODisabled struct {
   329  	onLaddr  string
   330  	RetryErr error
   331  }
   332  
   333  func (e ErrUDPGSODisabled) Error() string {
   334  	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
   335  }
   336  
   337  func (e ErrUDPGSODisabled) Unwrap() error {
   338  	return e.RetryErr
   339  }
   340  
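         // Illustrative sketch (not from the upstream file): how a caller might react
         // to ErrUDPGSODisabled. Send has already retried without GSO by the time
         // this error is returned, so it is usually just logged; the wrapped RetryErr
         // reports whether that retry itself failed.
         func exampleHandleGSODisabled(err error) error {
         	var gsoErr ErrUDPGSODisabled
         	if errors.As(err, &gsoErr) {
         		// A nil RetryErr means the non-GSO retry succeeded.
         		return gsoErr.RetryErr
         	}
         	return err
         }
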
   341  func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
   342  	s.mu.Lock()
   343  	blackhole := s.blackhole4
   344  	conn := s.ipv4
   345  	offload := s.ipv4TxOffload
   346  	br := batchWriter(s.ipv4PC)
   347  	is6 := false
   348  	if endpoint.DstIP().Is6() {
   349  		blackhole = s.blackhole6
   350  		conn = s.ipv6
   351  		br = s.ipv6PC
   352  		is6 = true
   353  		offload = s.ipv6TxOffload
   354  	}
   355  	s.mu.Unlock()
   356  
   357  	if blackhole {
   358  		return nil
   359  	}
   360  	if conn == nil {
   361  		return syscall.EAFNOSUPPORT
   362  	}
   363  
   364  	msgs := s.getMessages()
   365  	defer s.putMessages(msgs)
   366  	ua := s.udpAddrPool.Get().(*net.UDPAddr)
   367  	defer s.udpAddrPool.Put(ua)
   368  	if is6 {
   369  		as16 := endpoint.DstIP().As16()
   370  		copy(ua.IP, as16[:])
   371  		ua.IP = ua.IP[:16]
   372  	} else {
   373  		as4 := endpoint.DstIP().As4()
   374  		copy(ua.IP, as4[:])
   375  		ua.IP = ua.IP[:4]
   376  	}
   377  	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
   378  	var (
   379  		retried bool
   380  		err     error
   381  	)
   382  retry:
   383  	if offload {
   384  		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
   385  		err = s.send(conn, br, (*msgs)[:n])
   386  		if err != nil && offload && errShouldDisableUDPGSO(err) {
   387  			offload = false
   388  			s.mu.Lock()
   389  			if is6 {
   390  				s.ipv6TxOffload = false
   391  			} else {
   392  				s.ipv4TxOffload = false
   393  			}
   394  			s.mu.Unlock()
   395  			retried = true
   396  			goto retry
   397  		}
   398  	} else {
   399  		for i := range bufs {
   400  			(*msgs)[i].Addr = ua
   401  			(*msgs)[i].Buffers[0] = bufs[i]
   402  			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
   403  		}
   404  		err = s.send(conn, br, (*msgs)[:len(bufs)])
   405  	}
   406  	if retried {
   407  		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
   408  	}
   409  	return err
   410  }
   411  
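         // Illustrative sketch (not from the upstream file): sending a batch to a
         // single peer. Every buffer passed to one Send call must be destined for the
         // same endpoint; on Linux the offload path above may coalesce them into far
         // fewer system calls. The address is an example value.
         func exampleSend(bind Bind, packets [][]byte) error {
         	ep, err := bind.ParseEndpoint("203.0.113.5:51820")
         	if err != nil {
         		return err
         	}
         	return bind.Send(packets, ep)
         }
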
   412  func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
   413  	var (
   414  		n     int
   415  		err   error
   416  		start int
   417  	)
   418  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   419  		for {
   420  			n, err = pc.WriteBatch(msgs[start:], 0)
   421  			if err != nil || n == len(msgs[start:]) {
   422  				break
   423  			}
   424  			start += n
   425  		}
   426  	} else {
   427  		for _, msg := range msgs {
   428  			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
   429  			if err != nil {
   430  				break
   431  			}
   432  		}
   433  	}
   434  	return err
   435  }
   436  
   437  const (
   438  	// Exceeding these values results in EMSGSIZE. They account for layer3 and
    439  	// layer4 headers. IPv6 does not need to account for its own header
    440  	// because its payload length field excludes it.
   441  	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
   442  	maxIPv6PayloadLen = 1<<16 - 1 - 8
   443  
   444  	// This is a hard limit imposed by the kernel.
   445  	udpSegmentMaxDatagrams = 64
   446  )
   447  
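         // Worked out: both IP length fields are 16 bits, capping a packet at 65535
         // bytes. For IPv4 the total-length field includes the 20-byte IPv4 header
         // and the 8-byte UDP header, so maxIPv4PayloadLen = 65535 - 20 - 8 = 65507;
         // IPv6's payload-length field excludes its own header, so only the UDP
         // header is subtracted: maxIPv6PayloadLen = 65535 - 8 = 65527.
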
   448  type setGSOFunc func(control *[]byte, gsoSize uint16)
   449  
   450  func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
   451  	var (
   452  		base     = -1 // index of msg we are currently coalescing into
   453  		gsoSize  int  // segmentation size of msgs[base]
   454  		dgramCnt int  // number of dgrams coalesced into msgs[base]
   455  		endBatch bool // tracking flag to start a new batch on next iteration of bufs
   456  	)
   457  	maxPayloadLen := maxIPv4PayloadLen
   458  	if ep.DstIP().Is6() {
   459  		maxPayloadLen = maxIPv6PayloadLen
   460  	}
   461  	for i, buf := range bufs {
   462  		if i > 0 {
   463  			msgLen := len(buf)
   464  			baseLenBefore := len(msgs[base].Buffers[0])
   465  			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
   466  			if msgLen+baseLenBefore <= maxPayloadLen &&
   467  				msgLen <= gsoSize &&
   468  				msgLen <= freeBaseCap &&
   469  				dgramCnt < udpSegmentMaxDatagrams &&
   470  				!endBatch {
   471  				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
   472  				if i == len(bufs)-1 {
   473  					setGSO(&msgs[base].OOB, uint16(gsoSize))
   474  				}
   475  				dgramCnt++
   476  				if msgLen < gsoSize {
   477  					// A smaller than gsoSize packet on the tail is legal, but
   478  					// it must end the batch.
   479  					endBatch = true
   480  				}
   481  				continue
   482  			}
   483  		}
   484  		if dgramCnt > 1 {
   485  			setGSO(&msgs[base].OOB, uint16(gsoSize))
   486  		}
   487  		// Reset prior to incrementing base since we are preparing to start a
   488  		// new potential batch.
   489  		endBatch = false
   490  		base++
   491  		gsoSize = len(buf)
   492  		setSrcControl(&msgs[base].OOB, ep)
   493  		msgs[base].Buffers[0] = buf
   494  		msgs[base].Addr = addr
   495  		dgramCnt = 1
   496  	}
   497  	return base + 1
   498  }
   499  
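         // Worked example for the function above (illustrative, not from the upstream
         // file): given four bufs of 1200, 1200, 1200 and 600 bytes for one endpoint,
         // all four are appended into msgs[0].Buffers[0] (4200 bytes), the GSO
         // segment size is set to 1200 via setGSO, and the function returns 1; the
         // kernel then splits that single sendmsg back into four datagrams. The
         // 600-byte tail is allowed because a short final segment is legal, but it
         // flips endBatch so any later buffer would start a new message. The append
         // relies on bufs[0] having spare capacity, which the caller arranges by
         // handing in slices backed by full-size staging buffers.
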
   500  type getGSOFunc func(control []byte) (int, error)
   501  
   502  func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
   503  	for i := firstMsgAt; i < len(msgs); i++ {
   504  		msg := &msgs[i]
   505  		if msg.N == 0 {
   506  			return n, err
   507  		}
   508  		var (
   509  			gsoSize    int
   510  			start      int
   511  			end        = msg.N
   512  			numToSplit = 1
   513  		)
   514  		gsoSize, err = getGSO(msg.OOB[:msg.NN])
   515  		if err != nil {
   516  			return n, err
   517  		}
   518  		if gsoSize > 0 {
   519  			numToSplit = (msg.N + gsoSize - 1) / gsoSize
   520  			end = gsoSize
   521  		}
   522  		for j := 0; j < numToSplit; j++ {
   523  			if n > i {
   524  				return n, errors.New("splitting coalesced packet resulted in overflow")
   525  			}
   526  			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
   527  			msgs[n].N = copied
   528  			msgs[n].Addr = msg.Addr
   529  			start = end
   530  			end += gsoSize
   531  			if end > msg.N {
   532  				end = msg.N
   533  			}
   534  			n++
   535  		}
   536  		if i != n-1 {
   537  			// It is legal for bytes to move within msg.Buffers[0] as a result
   538  			// of splitting, so we only zero the source msg len when it is not
   539  			// the destination of the last split operation above.
   540  			msg.N = 0
   541  		}
   542  	}
   543  	return n, nil
   544  }
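
         // Illustrative sketch (not from the upstream file): exercising
         // splitCoalescedMessages with a fake GSO-size reader. A single 3000-byte
         // coalesced message read at offset 3, with a reported segment size of 1200,
         // is fanned out into three messages of 1200, 1200 and 600 bytes at the
         // front of the slice, and the call reports n = 3.
         func exampleSplitCoalesced() (int, error) {
         	msgs := make([]ipv6.Message, 4)
         	for i := range msgs {
         		msgs[i].Buffers = net.Buffers{make([]byte, 4096)}
         	}
         	// Pretend the kernel delivered one coalesced message at the read
         	// offset, as receiveIP does when rxOffload is enabled.
         	msgs[3].N = 3000
         	fakeGetGSO := func(control []byte) (int, error) { return 1200, nil }
         	return splitCoalescedMessages(msgs, 3, fakeGetGSO)
         }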