// Source: github.com/amnezia-vpn/amneziawg-go@v0.2.8/conn/bind_std.go

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"context"
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"net/netip"
    14  	"runtime"
    15  	"strconv"
    16  	"sync"
    17  	"syscall"
    18  
    19  	"golang.org/x/net/ipv4"
    20  	"golang.org/x/net/ipv6"
    21  )
    22  
    23  var (
    24  	_ Bind = (*StdNetBind)(nil)
    25  )
    26  
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
//
// Construct it with NewStdNetBind: the zero value is not usable because its
// pools have no New functions, and getMessages/Send type-assert whatever the
// pools return.
type StdNetBind struct {
	mu            sync.Mutex // protects all fields except as specified
	ipv4          *net.UDPConn
	ipv6          *net.UDPConn
	ipv4PC        *ipv4.PacketConn // nil unless runtime.GOOS is linux or android
	ipv6PC        *ipv6.PacketConn // nil unless runtime.GOOS is linux or android
	// Per-family offload flags: set from supportsUDPOffload in Open and reset
	// by Close. Tx selects coalesced (GSO) sends in Send, which clears the
	// flag if the kernel rejects a GSO send; Rx selects the split-on-receive
	// path in receiveIP.
	ipv4TxOffload bool
	ipv4RxOffload bool
	ipv6TxOffload bool
	ipv6RxOffload bool

	// these two fields are not guarded by mu
	udpAddrPool sync.Pool
	msgsPool    sync.Pool

	// When set, Send silently drops packets for that address family. Unlike
	// the pools above, these ARE guarded by mu; Close clears them.
	// NOTE(review): they are set outside this file.
	blackhole4 bool
	blackhole6 bool
}
    50  
    51  func NewStdNetBind() Bind {
    52  	return &StdNetBind{
    53  		udpAddrPool: sync.Pool{
    54  			New: func() any {
    55  				return &net.UDPAddr{
    56  					IP: make([]byte, 16),
    57  				}
    58  			},
    59  		},
    60  
    61  		msgsPool: sync.Pool{
    62  			New: func() any {
    63  				// ipv6.Message and ipv4.Message are interchangeable as they are
    64  				// both aliases for x/net/internal/socket.Message.
    65  				msgs := make([]ipv6.Message, IdealBatchSize)
    66  				for i := range msgs {
    67  					msgs[i].Buffers = make(net.Buffers, 1)
    68  					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
    69  				}
    70  				return &msgs
    71  			},
    72  		},
    73  	}
    74  }
    75  
// StdNetEndpoint is the Endpoint type produced and consumed by StdNetBind:
// a destination address/port plus an optional sticky source.
type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort
	// src is the current sticky source address and interface index, if
	// supported. Typically this is a PKTINFO structure from/for control
	// messages, see unix.PKTINFO for an example. ClearSrc truncates it to
	// zero length rather than freeing it, so the backing array is reused.
	src []byte
}
    84  
    85  var (
    86  	_ Bind     = (*StdNetBind)(nil)
    87  	_ Endpoint = &StdNetEndpoint{}
    88  )
    89  
    90  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
    91  	e, err := netip.ParseAddrPort(s)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	return &StdNetEndpoint{
    96  		AddrPort: e,
    97  	}, nil
    98  }
    99  
   100  func (e *StdNetEndpoint) ClearSrc() {
   101  	if e.src != nil {
   102  		// Truncate src, no need to reallocate.
   103  		e.src = e.src[:0]
   104  	}
   105  }
   106  
   107  func (e *StdNetEndpoint) DstIP() netip.Addr {
   108  	return e.AddrPort.Addr()
   109  }
   110  
// See control_default.go, control_linux.go, etc. for the platform-specific implementations of SrcIP and SrcIfidx.
   112  
   113  func (e *StdNetEndpoint) DstToBytes() []byte {
   114  	b, _ := e.AddrPort.MarshalBinary()
   115  	return b
   116  }
   117  
   118  func (e *StdNetEndpoint) DstToString() string {
   119  	return e.AddrPort.String()
   120  }
   121  
   122  func listenNet(network string, port int) (*net.UDPConn, int, error) {
   123  	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
   124  	if err != nil {
   125  		return nil, 0, err
   126  	}
   127  
   128  	// Retrieve port.
   129  	laddr := conn.LocalAddr()
   130  	uaddr, err := net.ResolveUDPAddr(
   131  		laddr.Network(),
   132  		laddr.String(),
   133  	)
   134  	if err != nil {
   135  		return nil, 0, err
   136  	}
   137  	return conn.(*net.UDPConn), uaddr.Port, nil
   138  }
   139  
   140  func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
   141  	s.mu.Lock()
   142  	defer s.mu.Unlock()
   143  
   144  	var err error
   145  	var tries int
   146  
   147  	if s.ipv4 != nil || s.ipv6 != nil {
   148  		return nil, 0, ErrBindAlreadyOpen
   149  	}
   150  
   151  	// Attempt to open ipv4 and ipv6 listeners on the same port.
   152  	// If uport is 0, we can retry on failure.
   153  again:
   154  	port := int(uport)
   155  	var v4conn, v6conn *net.UDPConn
   156  	var v4pc *ipv4.PacketConn
   157  	var v6pc *ipv6.PacketConn
   158  
   159  	v4conn, port, err = listenNet("udp4", port)
   160  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   161  		return nil, 0, err
   162  	}
   163  
   164  	// Listen on the same port as we're using for ipv4.
   165  	v6conn, port, err = listenNet("udp6", port)
   166  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   167  		v4conn.Close()
   168  		tries++
   169  		goto again
   170  	}
   171  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   172  		v4conn.Close()
   173  		return nil, 0, err
   174  	}
   175  	var fns []ReceiveFunc
   176  	if v4conn != nil {
   177  		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
   178  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   179  			v4pc = ipv4.NewPacketConn(v4conn)
   180  			s.ipv4PC = v4pc
   181  		}
   182  		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
   183  		s.ipv4 = v4conn
   184  	}
   185  	if v6conn != nil {
   186  		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
   187  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   188  			v6pc = ipv6.NewPacketConn(v6conn)
   189  			s.ipv6PC = v6pc
   190  		}
   191  		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
   192  		s.ipv6 = v6conn
   193  	}
   194  	if len(fns) == 0 {
   195  		return nil, 0, syscall.EAFNOSUPPORT
   196  	}
   197  
   198  	return fns, uint16(port), nil
   199  }
   200  
   201  func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
   202  	for i := range *msgs {
   203  		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
   204  		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
   205  	}
   206  	s.msgsPool.Put(msgs)
   207  }
   208  
   209  func (s *StdNetBind) getMessages() *[]ipv6.Message {
   210  	return s.msgsPool.Get().(*[]ipv6.Message)
   211  }
   212  
   213  var (
   214  	// If compilation fails here these are no longer the same underlying type.
   215  	_ ipv6.Message = ipv4.Message{}
   216  )
   217  
   218  type batchReader interface {
   219  	ReadBatch([]ipv6.Message, int) (int, error)
   220  }
   221  
   222  type batchWriter interface {
   223  	WriteBatch([]ipv6.Message, int) (int, error)
   224  }
   225  
// receiveIP reads datagrams from a single socket into bufs and returns how
// many entries it filled. For each filled entry i it stores the datagram
// length in sizes[i]. When that length is non-zero it also stores the sender
// in eps[i], with any sticky-source data pulled from the message's control
// bytes by getSrcFromControl. Zero-length entries are given no endpoint.
//
// On linux and android it reads with br.ReadBatch. With rxOffload, the batch
// is read into the tail of the pooled message slice (the last
// IdealBatchSize/udpSegmentMaxDatagrams entries). splitCoalescedMessages
// then spreads the coalesced datagrams back over the slice from index 0,
// i.e. into bufs. On other platforms exactly one datagram is read, via
// conn.ReadMsgUDP.
//
// NOTE(review): assumes len(bufs) equals the pooled batch length
// (IdealBatchSize). Longer would index out of range. Shorter would leave
// pooled messages, including the rxOffload read tail, pointing at whatever
// buffers a previous call left there. TODO confirm callers.
func (s *StdNetBind) receiveIP(
	br batchReader,
	conn *net.UDPConn,
	rxOffload bool,
	bufs [][]byte,
	sizes []int,
	eps []Endpoint,
) (n int, err error) {
	msgs := s.getMessages()
	// Point each pooled message at the caller's buffer and expose the full
	// OOB capacity so control messages can be received.
	for i := range bufs {
		(*msgs)[i].Buffers[0] = bufs[i]
		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
	}
	defer s.putMessages(msgs)
	var numMsgs int
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		if rxOffload {
			// Read into the tail, so splitting can expand datagrams toward
			// the front without overwriting messages it has not read yet.
			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
			if err != nil {
				return 0, err
			}
			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
			if err != nil {
				return 0, err
			}
		} else {
			numMsgs, err = br.ReadBatch(*msgs, 0)
			if err != nil {
				return 0, err
			}
		}
	} else {
		// One datagram per call here; BatchSize reports 1 on these platforms.
		msg := &(*msgs)[0]
		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
		if err != nil {
			return 0, err
		}
		numMsgs = 1
	}
	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		if sizes[i] == 0 {
			continue
		}
		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
		getSrcFromControl(msg.OOB[:msg.NN], ep)
		eps[i] = ep
	}
	return numMsgs, nil
}
   279  
   280  func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
   281  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   282  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   283  	}
   284  }
   285  
   286  func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
   287  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   288  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   289  	}
   290  }
   291  
   292  // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
   293  // rename the IdealBatchSize constant to BatchSize.
   294  func (s *StdNetBind) BatchSize() int {
   295  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   296  		return IdealBatchSize
   297  	}
   298  	return 1
   299  }
   300  
   301  func (s *StdNetBind) GetOffloadInfo() string {
   302  	return fmt.Sprintf("ipv4TxOffload: %v, ipv4RxOffload: %v\nipv6TxOffload: %v, ipv6RxOffload: %v",
   303  		s.ipv4TxOffload, s.ipv4RxOffload, s.ipv6TxOffload, s.ipv6RxOffload)
   304  }
   305  
   306  func (s *StdNetBind) Close() error {
   307  	s.mu.Lock()
   308  	defer s.mu.Unlock()
   309  
   310  	var err1, err2 error
   311  	if s.ipv4 != nil {
   312  		err1 = s.ipv4.Close()
   313  		s.ipv4 = nil
   314  		s.ipv4PC = nil
   315  	}
   316  	if s.ipv6 != nil {
   317  		err2 = s.ipv6.Close()
   318  		s.ipv6 = nil
   319  		s.ipv6PC = nil
   320  	}
   321  	s.blackhole4 = false
   322  	s.blackhole6 = false
   323  	s.ipv4TxOffload = false
   324  	s.ipv4RxOffload = false
   325  	s.ipv6TxOffload = false
   326  	s.ipv6RxOffload = false
   327  	if err1 != nil {
   328  		return err1
   329  	}
   330  	return err2
   331  }
   332  
   333  type ErrUDPGSODisabled struct {
   334  	onLaddr  string
   335  	RetryErr error
   336  }
   337  
   338  func (e ErrUDPGSODisabled) Error() string {
   339  	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
   340  }
   341  
   342  func (e ErrUDPGSODisabled) Unwrap() error {
   343  	return e.RetryErr
   344  }
   345  
   346  func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
   347  	s.mu.Lock()
   348  	blackhole := s.blackhole4
   349  	conn := s.ipv4
   350  	offload := s.ipv4TxOffload
   351  	br := batchWriter(s.ipv4PC)
   352  	is6 := false
   353  	if endpoint.DstIP().Is6() {
   354  		blackhole = s.blackhole6
   355  		conn = s.ipv6
   356  		br = s.ipv6PC
   357  		is6 = true
   358  		offload = s.ipv6TxOffload
   359  	}
   360  	s.mu.Unlock()
   361  
   362  	if blackhole {
   363  		return nil
   364  	}
   365  	if conn == nil {
   366  		return syscall.EAFNOSUPPORT
   367  	}
   368  
   369  	msgs := s.getMessages()
   370  	defer s.putMessages(msgs)
   371  	ua := s.udpAddrPool.Get().(*net.UDPAddr)
   372  	defer s.udpAddrPool.Put(ua)
   373  	if is6 {
   374  		as16 := endpoint.DstIP().As16()
   375  		copy(ua.IP, as16[:])
   376  		ua.IP = ua.IP[:16]
   377  	} else {
   378  		as4 := endpoint.DstIP().As4()
   379  		copy(ua.IP, as4[:])
   380  		ua.IP = ua.IP[:4]
   381  	}
   382  	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
   383  	var (
   384  		retried bool
   385  		err     error
   386  	)
   387  retry:
   388  	if offload {
   389  		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
   390  		err = s.send(conn, br, (*msgs)[:n])
   391  		if err != nil && offload && errShouldDisableUDPGSO(err) {
   392  			offload = false
   393  			s.mu.Lock()
   394  			if is6 {
   395  				s.ipv6TxOffload = false
   396  			} else {
   397  				s.ipv4TxOffload = false
   398  			}
   399  			s.mu.Unlock()
   400  			retried = true
   401  			goto retry
   402  		}
   403  	} else {
   404  		for i := range bufs {
   405  			(*msgs)[i].Addr = ua
   406  			(*msgs)[i].Buffers[0] = bufs[i]
   407  			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
   408  		}
   409  		err = s.send(conn, br, (*msgs)[:len(bufs)])
   410  	}
   411  	if retried {
   412  		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
   413  	}
   414  	return err
   415  }
   416  
   417  func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
   418  	var (
   419  		n     int
   420  		err   error
   421  		start int
   422  	)
   423  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   424  		for {
   425  			n, err = pc.WriteBatch(msgs[start:], 0)
   426  			if err != nil || n == len(msgs[start:]) {
   427  				break
   428  			}
   429  			start += n
   430  		}
   431  	} else {
   432  		for _, msg := range msgs {
   433  			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
   434  			if err != nil {
   435  				break
   436  			}
   437  		}
   438  	}
   439  	return err
   440  }
   441  
   442  const (
   443  	// Exceeding these values results in EMSGSIZE. They account for layer3 and
   444  	// layer4 headers. IPv6 does not need to account for itself as the payload
   445  	// length field is self excluding.
   446  	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
   447  	maxIPv6PayloadLen = 1<<16 - 1 - 8
   448  
   449  	// This is a hard limit imposed by the kernel.
   450  	udpSegmentMaxDatagrams = 64
   451  )
   452  
   453  type setGSOFunc func(control *[]byte, gsoSize uint16)
   454  
// coalesceMessages fills msgs with bufs and returns the number of msgs
// entries used (0 for empty bufs). Every message is addressed to addr and
// gets ep's sticky source via setSrcControl. Where possible, consecutive
// buffers are merged into one message for a UDP GSO send.
//
// A run starts at a "base" message, and the length of its buffer becomes
// the run's segment size (gsoSize). A following buffer is appended to the
// base buffer only if all of these hold:
//   - the merged payload stays within the payload limit for ep's family;
//   - the buffer is no longer than gsoSize;
//   - the base buffer has enough spare capacity;
//   - fewer than udpSegmentMaxDatagrams datagrams are in the run;
//   - a shorter-than-gsoSize buffer has not already ended the run.
// Runs holding more than one datagram have gsoSize written to their control
// data via setGSO.
//
// Appending writes into the spare capacity of the base buffer's backing
// array; the capacity check means append never reallocates. The slice
// headers inside bufs are not modified.
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
	var (
		base     = -1 // index of msg we are currently coalescing into
		gsoSize  int  // segmentation size of msgs[base]
		dgramCnt int  // number of dgrams coalesced into msgs[base]
		endBatch bool // tracking flag to start a new batch on next iteration of bufs
	)
	maxPayloadLen := maxIPv4PayloadLen
	if ep.DstIP().Is6() {
		maxPayloadLen = maxIPv6PayloadLen
	}
	for i, buf := range bufs {
		if i > 0 {
			msgLen := len(buf)
			baseLenBefore := len(msgs[base].Buffers[0])
			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
			if msgLen+baseLenBefore <= maxPayloadLen &&
				msgLen <= gsoSize &&
				msgLen <= freeBaseCap &&
				dgramCnt < udpSegmentMaxDatagrams &&
				!endBatch {
				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
				if i == len(bufs)-1 {
					// Last buffer: no later iteration will close this run,
					// so record its segment size now.
					setGSO(&msgs[base].OOB, uint16(gsoSize))
				}
				dgramCnt++
				if msgLen < gsoSize {
					// A smaller than gsoSize packet on the tail is legal, but
					// it must end the batch.
					endBatch = true
				}
				continue
			}
		}
		// buf starts a new run. First close the previous run, which only
		// needs a GSO size if it actually merged datagrams.
		if dgramCnt > 1 {
			setGSO(&msgs[base].OOB, uint16(gsoSize))
		}
		// Reset prior to incrementing base since we are preparing to start a
		// new potential batch.
		endBatch = false
		base++
		gsoSize = len(buf)
		setSrcControl(&msgs[base].OOB, ep)
		msgs[base].Buffers[0] = buf
		msgs[base].Addr = addr
		dgramCnt = 1
	}
	return base + 1
}
   504  
   505  type getGSOFunc func(control []byte) (int, error)
   506  
// splitCoalescedMessages undoes receive-side coalescing in place.
//
// Starting at msgs[firstMsgAt], each message may hold several back-to-back
// datagrams of gsoSize bytes, the last of which may be shorter. getGSO reads
// gsoSize from the message's received control data; 0 means the message is
// a single datagram. The datagrams are copied in order into msgs[0],
// msgs[1], and so on. Each destination gets only its Buffers[0] contents, N
// and Addr written; OOB and NN are left untouched. A destination buffer
// shorter than a datagram truncates it, and N then records the bytes
// actually copied.
//
// Processing stops at the first message with N == 0. The function returns
// the number of datagrams produced. A getGSO error is returned together with
// the count produced so far. An overflow error is returned if the next
// destination index would pass the index of the message being split, since
// that would overwrite data not yet read.
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			// No more data was read. err is always nil at this point.
			return n, err
		}
		var (
			gsoSize    int
			start      int
			end        = msg.N
			numToSplit = 1
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			// ceil(N / gsoSize) segments; the first ends at gsoSize.
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			if n > i {
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}
   548  	return n, nil
   549  }