github.com/sagernet/wireguard-go@v0.0.0-20231215174105-89dec3b2f3e8/conn/bind_std.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"net/netip"
    13  	"runtime"
    14  	"strconv"
    15  	"sync"
    16  	"syscall"
    17  
    18  	"github.com/sagernet/sing/common"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  	"golang.org/x/net/ipv4"
    21  	"golang.org/x/net/ipv6"
    22  )
    23  
// Compile-time assertion that *StdNetBind satisfies the Bind interface.
var _ Bind = (*StdNetBind)(nil)
    25  
    26  // StdNetBind implements Bind for all platforms. While Windows has its own Bind
    27  // (see bind_windows.go), it may fall back to StdNetBind.
    28  // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
    29  // methods for sending and receiving multiple datagrams per-syscall. See the
    30  // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
    31  type StdNetBind struct {
    32  	listener            Listener
    33  	reservedForEndpoint map[netip.AddrPort][3]uint8
    34  
    35  	mu            sync.Mutex // protects all fields except as specified
    36  	ipv4          *net.UDPConn
    37  	ipv6          *net.UDPConn
    38  	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
    39  	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
    40  	ipv4TxOffload bool
    41  	ipv4RxOffload bool
    42  	ipv6TxOffload bool
    43  	ipv6RxOffload bool
    44  
    45  	// these two fields are not guarded by mu
    46  	udpAddrPool sync.Pool
    47  	msgsPool    sync.Pool
    48  
    49  	blackhole4 bool
    50  	blackhole6 bool
    51  }
    52  
    53  func NewStdNetBind(listener Listener) Bind {
    54  	return &StdNetBind{
    55  		listener: listener,
    56  
    57  		udpAddrPool: sync.Pool{
    58  			New: func() any {
    59  				return &net.UDPAddr{
    60  					IP: make([]byte, 16),
    61  				}
    62  			},
    63  		},
    64  
    65  		msgsPool: sync.Pool{
    66  			New: func() any {
    67  				// ipv6.Message and ipv4.Message are interchangeable as they are
    68  				// both aliases for x/net/internal/socket.Message.
    69  				msgs := make([]ipv6.Message, IdealBatchSize)
    70  				for i := range msgs {
    71  					msgs[i].Buffers = make(net.Buffers, 1)
    72  					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
    73  				}
    74  				return &msgs
    75  			},
    76  		},
    77  	}
    78  }
    79  
// StdNetEndpoint is the Endpoint implementation produced by StdNetBind, both
// when parsing endpoint strings and when receiving packets.
type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort
	// src is the current sticky source address and interface index, if
	// supported. Typically this is a PKTINFO structure from/for control
	// messages, see unix.PKTINFO for an example.
	src []byte
}
    88  
    89  var (
    90  	_ Bind     = (*StdNetBind)(nil)
    91  	_ Endpoint = &StdNetEndpoint{}
    92  )
    93  
    94  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
    95  	e, err := netip.ParseAddrPort(s)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	return &StdNetEndpoint{
   100  		AddrPort: e,
   101  	}, nil
   102  }
   103  
   104  func (e *StdNetEndpoint) ClearSrc() {
   105  	if e.src != nil {
   106  		// Truncate src, no need to reallocate.
   107  		e.src = e.src[:0]
   108  	}
   109  }
   110  
   111  func (e *StdNetEndpoint) DstIP() netip.Addr {
   112  	return e.AddrPort.Addr()
   113  }
   114  
   115  // See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
   116  
   117  func (e *StdNetEndpoint) DstToBytes() []byte {
   118  	b, _ := e.AddrPort.MarshalBinary()
   119  	return b
   120  }
   121  
   122  func (e *StdNetEndpoint) DstToString() string {
   123  	return e.AddrPort.String()
   124  }
   125  
   126  func listenNet(listener Listener, network string, port int) (*net.UDPConn, int, error) {
   127  	conn, err := listener.ListenPacketCompat(network, ":"+strconv.Itoa(port))
   128  	if err != nil {
   129  		return nil, 0, err
   130  	}
   131  
   132  	// Retrieve port.
   133  	laddr := conn.LocalAddr()
   134  	uaddr, err := net.ResolveUDPAddr(
   135  		laddr.Network(),
   136  		laddr.String(),
   137  	)
   138  	if err != nil {
   139  		return nil, 0, err
   140  	}
   141  	return conn.(*net.UDPConn), uaddr.Port, nil
   142  }
   143  
   144  func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
   145  	s.mu.Lock()
   146  	defer s.mu.Unlock()
   147  
   148  	var err error
   149  	var tries int
   150  
   151  	if s.ipv4 != nil || s.ipv6 != nil {
   152  		return nil, 0, ErrBindAlreadyOpen
   153  	}
   154  
   155  	// Attempt to open ipv4 and ipv6 listeners on the same port.
   156  	// If uport is 0, we can retry on failure.
   157  again:
   158  	port := int(uport)
   159  	var v4conn, v6conn *net.UDPConn
   160  	var v4pc *ipv4.PacketConn
   161  	var v6pc *ipv6.PacketConn
   162  
   163  	v4conn, port, err = listenNet(s.listener, "udp4", port)
   164  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   165  		return nil, 0, err
   166  	}
   167  
   168  	// Listen on the same port as we're using for ipv4.
   169  	v6conn, port, err = listenNet(s.listener, "udp6", port)
   170  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   171  		v4conn.Close()
   172  		tries++
   173  		goto again
   174  	}
   175  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   176  		v4conn.Close()
   177  		return nil, 0, err
   178  	}
   179  	var fns []ReceiveFunc
   180  	if v4conn != nil {
   181  		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
   182  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   183  			v4pc = ipv4.NewPacketConn(v4conn)
   184  			s.ipv4PC = v4pc
   185  		}
   186  		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
   187  		s.ipv4 = v4conn
   188  	}
   189  	if v6conn != nil {
   190  		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
   191  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   192  			v6pc = ipv6.NewPacketConn(v6conn)
   193  			s.ipv6PC = v6pc
   194  		}
   195  		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
   196  		s.ipv6 = v6conn
   197  	}
   198  	if len(fns) == 0 {
   199  		return nil, 0, syscall.EAFNOSUPPORT
   200  	}
   201  
   202  	return fns, uint16(port), nil
   203  }
   204  
   205  func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
   206  	for i := range *msgs {
   207  		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
   208  		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
   209  	}
   210  	s.msgsPool.Put(msgs)
   211  }
   212  
   213  func (s *StdNetBind) getMessages() *[]ipv6.Message {
   214  	return s.msgsPool.Get().(*[]ipv6.Message)
   215  }
   216  
// Compile-time check that ipv6.Message and ipv4.Message are assignable to one
// another; this file uses []ipv6.Message for both address families.
// If compilation fails here these are no longer the same underlying type.
var _ ipv6.Message = ipv4.Message{}
   219  
   220  type batchReader interface {
   221  	ReadBatch([]ipv6.Message, int) (int, error)
   222  }
   223  
   224  type batchWriter interface {
   225  	WriteBatch([]ipv6.Message, int) (int, error)
   226  }
   227  
   228  func (s *StdNetBind) receiveIP(
   229  	br batchReader,
   230  	conn *net.UDPConn,
   231  	rxOffload bool,
   232  	bufs [][]byte,
   233  	sizes []int,
   234  	eps []Endpoint,
   235  ) (n int, err error) {
   236  	msgs := s.getMessages()
   237  	for i := range bufs {
   238  		(*msgs)[i].Buffers[0] = bufs[i]
   239  		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
   240  	}
   241  	defer s.putMessages(msgs)
   242  	var numMsgs int
   243  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   244  		if rxOffload {
   245  			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
   246  			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
   247  			if err != nil {
   248  				return 0, err
   249  			}
   250  			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
   251  			if err != nil {
   252  				return 0, err
   253  			}
   254  		} else {
   255  			numMsgs, err = br.ReadBatch(*msgs, 0)
   256  			if err != nil {
   257  				return 0, err
   258  			}
   259  		}
   260  	} else {
   261  		msg := &(*msgs)[0]
   262  		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
   263  		if err != nil {
   264  			return 0, err
   265  		}
   266  		numMsgs = 1
   267  	}
   268  	for i := 0; i < numMsgs; i++ {
   269  		msg := &(*msgs)[i]
   270  		sizes[i] = msg.N
   271  		if sizes[i] == 0 {
   272  			continue
   273  		}
   274  		if msg.N > 3 {
   275  			common.ClearArray(bufs[i][1:4])
   276  		}
   277  		ep := &StdNetEndpoint{AddrPort: M.AddrPortFromNet(msg.Addr)} // TODO: remove allocation
   278  		getSrcFromControl(msg.OOB[:msg.NN], ep)
   279  		eps[i] = ep
   280  	}
   281  	return numMsgs, nil
   282  }
   283  
   284  func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
   285  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   286  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   287  	}
   288  }
   289  
   290  func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
   291  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   292  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   293  	}
   294  }
   295  
   296  // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
   297  // rename the IdealBatchSize constant to BatchSize.
   298  func (s *StdNetBind) BatchSize() int {
   299  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   300  		return IdealBatchSize
   301  	}
   302  	return 1
   303  }
   304  
   305  func (s *StdNetBind) Close() error {
   306  	s.mu.Lock()
   307  	defer s.mu.Unlock()
   308  
   309  	var err1, err2 error
   310  	if s.ipv4 != nil {
   311  		err1 = s.ipv4.Close()
   312  		s.ipv4 = nil
   313  		s.ipv4PC = nil
   314  	}
   315  	if s.ipv6 != nil {
   316  		err2 = s.ipv6.Close()
   317  		s.ipv6 = nil
   318  		s.ipv6PC = nil
   319  	}
   320  	s.blackhole4 = false
   321  	s.blackhole6 = false
   322  	s.ipv4TxOffload = false
   323  	s.ipv4RxOffload = false
   324  	s.ipv6TxOffload = false
   325  	s.ipv6RxOffload = false
   326  	if err1 != nil {
   327  		return err1
   328  	}
   329  	return err2
   330  }
   331  
   332  type ErrUDPGSODisabled struct {
   333  	onLaddr  string
   334  	RetryErr error
   335  }
   336  
   337  func (e ErrUDPGSODisabled) Error() string {
   338  	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
   339  }
   340  
   341  func (e ErrUDPGSODisabled) Unwrap() error {
   342  	return e.RetryErr
   343  }
   344  
   345  func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
   346  	s.mu.Lock()
   347  	blackhole := s.blackhole4
   348  	conn := s.ipv4
   349  	offload := s.ipv4TxOffload
   350  	br := batchWriter(s.ipv4PC)
   351  	is6 := false
   352  	if endpoint.DstIP().Is6() {
   353  		blackhole = s.blackhole6
   354  		conn = s.ipv6
   355  		br = s.ipv6PC
   356  		is6 = true
   357  		offload = s.ipv6TxOffload
   358  	}
   359  	s.mu.Unlock()
   360  
   361  	if blackhole {
   362  		return nil
   363  	}
   364  	if conn == nil {
   365  		return syscall.EAFNOSUPPORT
   366  	}
   367  
   368  	msgs := s.getMessages()
   369  	defer s.putMessages(msgs)
   370  	ua := s.udpAddrPool.Get().(*net.UDPAddr)
   371  	defer s.udpAddrPool.Put(ua)
   372  	if is6 {
   373  		as16 := endpoint.DstIP().As16()
   374  		copy(ua.IP, as16[:])
   375  		ua.IP = ua.IP[:16]
   376  	} else {
   377  		as4 := endpoint.DstIP().As4()
   378  		copy(ua.IP, as4[:])
   379  		ua.IP = ua.IP[:4]
   380  	}
   381  	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
   382  	var (
   383  		retried bool
   384  		err     error
   385  	)
   386  retry:
   387  	if offload {
   388  		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
   389  		err = s.send(conn, br, (*msgs)[:n])
   390  		if err != nil && offload && errShouldDisableUDPGSO(err) {
   391  			offload = false
   392  			s.mu.Lock()
   393  			if is6 {
   394  				s.ipv6TxOffload = false
   395  			} else {
   396  				s.ipv4TxOffload = false
   397  			}
   398  			s.mu.Unlock()
   399  			retried = true
   400  			goto retry
   401  		}
   402  	} else {
   403  		for i := range bufs {
   404  			bufI := bufs[i]
   405  			if len(bufI) > 3 {
   406  				reserved, loaded := s.reservedForEndpoint[endpoint.(*StdNetEndpoint).AddrPort]
   407  				if loaded {
   408  					copy(bufI[1:4], reserved[:])
   409  				}
   410  			}
   411  
   412  			(*msgs)[i].Addr = ua
   413  			(*msgs)[i].Buffers[0] = bufs[i]
   414  			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
   415  		}
   416  		err = s.send(conn, br, (*msgs)[:len(bufs)])
   417  	}
   418  	if retried {
   419  		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
   420  	}
   421  	return err
   422  }
   423  
   424  func (s *StdNetBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
   425  	s.reservedForEndpoint[destination] = reserved
   426  }
   427  
   428  func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
   429  	var (
   430  		n     int
   431  		err   error
   432  		start int
   433  	)
   434  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   435  		for {
   436  			n, err = pc.WriteBatch(msgs[start:], 0)
   437  			if err != nil || n == len(msgs[start:]) {
   438  				break
   439  			}
   440  			start += n
   441  		}
   442  	} else {
   443  		for _, msg := range msgs {
   444  			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
   445  			if err != nil {
   446  				break
   447  			}
   448  		}
   449  	}
   450  	return err
   451  }
   452  
// Limits applied by coalesceMessages and receiveIP when batching datagrams.
const (
	// Exceeding these values results in EMSGSIZE. They account for layer3 and
	// layer4 headers. IPv6 does not need to account for itself as the payload
	// length field is self excluding.
	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
	maxIPv6PayloadLen = 1<<16 - 1 - 8

	// This is a hard limit imposed by the kernel.
	udpSegmentMaxDatagrams = 64
)
   463  
// setGSOFunc records gsoSize in the control (OOB) buffer of an outgoing
// message; coalesceMessages calls it with the segment size of a batch.
type setGSOFunc func(control *[]byte, gsoSize uint16)
   465  
// coalesceMessages packs bufs into msgs for sending to addr, merging runs of
// consecutive packets into a single message where possible, and returns the
// number of msgs entries populated.
//
// A packet is appended to the current message's Buffers[0] only when the
// combined length stays within the address family's payload limit, the packet
// is no longer than the first packet of the batch (gsoSize), it fits in the
// spare capacity of that first packet's buffer (so append never reallocates),
// fewer than udpSegmentMaxDatagrams datagrams are already in the batch, and
// no shorter-than-gsoSize packet has already closed the batch. setGSO is
// called with gsoSize on messages that end up carrying more than one
// datagram.
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
	var (
		base     = -1 // index of msg we are currently coalescing into
		gsoSize  int  // segmentation size of msgs[base]
		dgramCnt int  // number of dgrams coalesced into msgs[base]
		endBatch bool // tracking flag to start a new batch on next iteration of bufs
	)
	maxPayloadLen := maxIPv4PayloadLen
	if ep.DstIP().Is6() {
		maxPayloadLen = maxIPv6PayloadLen
	}
	for i, buf := range bufs {
		// The first packet always starts a batch; later ones try to join the
		// current batch first.
		if i > 0 {
			msgLen := len(buf)
			baseLenBefore := len(msgs[base].Buffers[0])
			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
			if msgLen+baseLenBefore <= maxPayloadLen &&
				msgLen <= gsoSize &&
				msgLen <= freeBaseCap &&
				dgramCnt < udpSegmentMaxDatagrams &&
				!endBatch {
				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
				if i == len(bufs)-1 {
					// Last packet overall: no later iteration will finalize
					// this batch, so set its GSO size now.
					setGSO(&msgs[base].OOB, uint16(gsoSize))
				}
				dgramCnt++
				if msgLen < gsoSize {
					// A smaller than gsoSize packet on the tail is legal, but
					// it must end the batch.
					endBatch = true
				}
				continue
			}
		}
		// Finalize the batch being left behind if it holds several datagrams.
		if dgramCnt > 1 {
			setGSO(&msgs[base].OOB, uint16(gsoSize))
		}
		// Reset prior to incrementing base since we are preparing to start a
		// new potential batch.
		endBatch = false
		base++
		gsoSize = len(buf)
		setSrcControl(&msgs[base].OOB, ep)
		msgs[base].Buffers[0] = buf
		msgs[base].Addr = addr
		dgramCnt = 1
	}
	return base + 1
}
   515  
// getGSOFunc extracts the segment size from a received message's control
// (OOB) data. splitCoalescedMessages treats a result of zero or less as "not
// coalesced".
type getGSOFunc func(control []byte) (int, error)
   517  
// splitCoalescedMessages expands the messages received into
// msgs[firstMsgAt:] into individual datagrams stored from msgs[0] onwards,
// and returns how many datagrams were produced.
//
// Each source message is cut into segments of the size reported by getGSO
// (the final segment may be shorter); a message with no positive segment size
// is copied through as one datagram. Processing stops at the first source
// message with N == 0. An error is returned if getGSO fails or if the output
// position would pass the source message currently being split.
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			// Nothing was read into this or any later message.
			return n, err
		}
		var (
			gsoSize    int
			start      int
			end        = msg.N
			numToSplit = 1
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			// Ceiling division: a trailing partial segment still counts.
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			if n > i {
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}