github.com/yaling888/clash@v1.53.0/transport/wireguard/bind_std.go (about)

     1  //go:build !nogvisor
     2  
     3  package wireguard
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"net/netip"
    11  	"runtime"
    12  	"strconv"
    13  	"sync"
    14  	"syscall"
    15  	_ "unsafe"
    16  
    17  	"golang.org/x/net/ipv4"
    18  	"golang.org/x/net/ipv6"
    19  	wg "golang.zx2c4.com/wireguard/conn"
    20  )
    21  
// The functions below are declared without bodies and bound with
// //go:linkname to unexported helpers in golang.zx2c4.com/wireguard/conn, so
// this bind can reuse wireguard-go's control-message and UDP offload plumbing.
// NOTE(review): the linker does not check these signatures against the linked
// definitions; re-verify every one of them whenever wireguard-go is upgraded.

// getSrcFromControl is called by receiveIP with a received message's control
// data and the endpoint built for that message; presumably it records the
// packet's local (sticky) source address into ep — see wireguard-go.
//
//go:linkname getSrcFromControl golang.zx2c4.com/wireguard/conn.getSrcFromControl
func getSrcFromControl(control []byte, ep *wg.StdNetEndpoint)

// setSrcControl is called by Send for each non-coalesced message; presumably
// it encodes ep's sticky source address into the control buffer.
//
//go:linkname setSrcControl golang.zx2c4.com/wireguard/conn.setSrcControl
func setSrcControl(control *[]byte, ep *wg.StdNetEndpoint)

// getGSOSize is handed to splitCoalescedMessages by receiveIP; presumably it
// extracts the coalesced segment size from received control data.
//
//go:linkname getGSOSize golang.zx2c4.com/wireguard/conn.getGSOSize
func getGSOSize(control []byte) (int, error)

// setGSOSize is handed to coalesceMessages by Send; presumably it writes the
// UDP GSO segment size into an outgoing control buffer.
//
//go:linkname setGSOSize golang.zx2c4.com/wireguard/conn.setGSOSize
func setGSOSize(control *[]byte, gsoSize uint16)

// supportsUDPOffload is queried by Open for each opened socket; its results
// are stored in the ipv4/ipv6 TxOffload and RxOffload fields.
//
//go:linkname supportsUDPOffload golang.zx2c4.com/wireguard/conn.supportsUDPOffload
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool)

// errShouldDisableUDPGSO is consulted by Send after an offloaded write fails;
// when it returns true, Send turns TX offload off for that family and retries.
//
//go:linkname errShouldDisableUDPGSO golang.zx2c4.com/wireguard/conn.errShouldDisableUDPGSO
func errShouldDisableUDPGSO(err error) bool

// coalesceMessages is used by Send when TX offload is enabled; it fills msgs
// from bufs and returns how many leading entries of msgs should be written.
//
//go:linkname coalesceMessages golang.zx2c4.com/wireguard/conn.coalesceMessages
func coalesceMessages(addr *net.UDPAddr, ep *wg.StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int

// splitCoalescedMessages is used by receiveIP when RX offload is enabled.
// Messages were read starting at msgs[firstMsgAt]; the returned n is the
// number of (split) messages now available at the front of msgs.
//
//go:linkname splitCoalescedMessages golang.zx2c4.com/wireguard/conn.splitCoalescedMessages
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error)
    45  
    46  const udpSegmentMaxDatagrams = 64 // This is a hard limit imposed by the kernel.
    47  
    48  type setGSOFunc func(control *[]byte, gsoSize uint16)
    49  
    50  type getGSOFunc func(control []byte) (int, error)
    51  
    52  var _ wg.Bind = (*StdNetBind)(nil)
    53  
// StdNetBind is a wg.Bind backed by standard-library UDP sockets. It closely
// mirrors wireguard-go's conn.StdNetBind and adds three things:
//   - caller-supplied socket control functions
//   - an optional interface that selects the listen address
//   - rewriting of the reserved header bytes of outgoing packets
type StdNetBind struct {
	mu            sync.Mutex // protects all fields except as specified
	ipv4          *net.UDPConn
	ipv6          *net.UDPConn
	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux
	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux
	// Offload capabilities recorded by Open; Send may clear the Tx flags at
	// runtime if the kernel rejects UDP GSO.
	ipv4TxOffload bool
	ipv4RxOffload bool
	ipv6TxOffload bool
	ipv6RxOffload bool

	// these two fields are not guarded by mu
	udpAddrPool sync.Pool
	msgsPool    sync.Pool

	// When set, Send silently drops packets for that family. Close resets
	// them; nothing in this file sets them to true.
	blackhole4 bool
	blackhole6 bool

	// controlFns run, in order, on every socket created via listenConfig.
	controlFns []func(network, address string, c syscall.RawConn) error
	// interfaceName is passed to getListenIP to choose the listen address.
	interfaceName string
	// reserved, when non-nil, supplies the bytes setReserved writes into
	// bytes 1-3 of outgoing packets. Send reads it without holding mu, so it
	// is expected to stay unchanged after construction.
	reserved []byte
}
    76  
    77  func (s *StdNetBind) setReserved(b []byte) {
    78  	if len(b) < 4 || s.reserved == nil {
    79  		return
    80  	}
    81  	b[1] = s.reserved[0]
    82  	b[2] = s.reserved[1]
    83  	b[3] = s.reserved[2]
    84  }
    85  
    86  func (s *StdNetBind) resetReserved(b []byte) {
    87  	if len(b) < 4 {
    88  		return
    89  	}
    90  	b[1] = 0x00
    91  	b[2] = 0x00
    92  	b[3] = 0x00
    93  }
    94  
    95  func (s *StdNetBind) listenConfig() *net.ListenConfig {
    96  	return &net.ListenConfig{
    97  		Control: func(network, address string, c syscall.RawConn) error {
    98  			for _, fn := range s.controlFns {
    99  				if err := fn(network, address, c); err != nil {
   100  					return err
   101  				}
   102  			}
   103  			return nil
   104  		},
   105  	}
   106  }
   107  
   108  func (s *StdNetBind) listenNet(network string, port int) (*net.UDPConn, int, error) {
   109  	listenIP, err := getListenIP(network, s.interfaceName)
   110  	if err != nil {
   111  		return nil, 0, err
   112  	}
   113  
   114  	conn, err := s.listenConfig().ListenPacket(context.Background(), network, listenIP+":"+strconv.Itoa(port))
   115  	if err != nil {
   116  		return nil, 0, err
   117  	}
   118  
   119  	// Retrieve port.
   120  	laddr := conn.LocalAddr()
   121  	uaddr, err := net.ResolveUDPAddr(
   122  		laddr.Network(),
   123  		laddr.String(),
   124  	)
   125  	if err != nil {
   126  		return nil, 0, err
   127  	}
   128  	return conn.(*net.UDPConn), uaddr.Port, nil
   129  }
   130  
   131  func (s *StdNetBind) SetMark(mark uint32) error {
   132  	return nil
   133  }
   134  
   135  func (*StdNetBind) ParseEndpoint(s string) (wg.Endpoint, error) {
   136  	e, err := netip.ParseAddrPort(s)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	return &wg.StdNetEndpoint{
   141  		AddrPort: e,
   142  	}, nil
   143  }
   144  
   145  func (s *StdNetBind) UpdateControlFns(controlFns []func(network, address string, c syscall.RawConn) error) {
   146  	s.controlFns = controlFns
   147  }
   148  
   149  func NewStdNetBind(
   150  	controlFns []func(network, address string, c syscall.RawConn) error,
   151  	interfaceName string,
   152  	reserved []byte,
   153  ) wg.Bind {
   154  	return &StdNetBind{
   155  		udpAddrPool: sync.Pool{
   156  			New: func() any {
   157  				return &net.UDPAddr{
   158  					IP: make([]byte, 16),
   159  				}
   160  			},
   161  		},
   162  
   163  		msgsPool: sync.Pool{
   164  			New: func() any {
   165  				// ipv6.Message and ipv4.Message are interchangeable as they are
   166  				// both aliases for x/net/internal/socket.Message.
   167  				msgs := make([]ipv6.Message, wg.IdealBatchSize)
   168  				for i := range msgs {
   169  					msgs[i].Buffers = make(net.Buffers, 1)
   170  					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
   171  				}
   172  				return &msgs
   173  			},
   174  		},
   175  
   176  		controlFns:    controlFns,
   177  		interfaceName: interfaceName,
   178  		reserved:      reserved,
   179  	}
   180  }
   181  
   182  func (s *StdNetBind) Open(uport uint16) ([]wg.ReceiveFunc, uint16, error) {
   183  	s.mu.Lock()
   184  	defer s.mu.Unlock()
   185  
   186  	var err error
   187  	var tries int
   188  
   189  	if s.ipv4 != nil || s.ipv6 != nil {
   190  		return nil, 0, wg.ErrBindAlreadyOpen
   191  	}
   192  
   193  	// Attempt to open ipv4 and ipv6 listeners on the same port.
   194  	// If uport is 0, we can retry on failure.
   195  again:
   196  	port := int(uport)
   197  	var v4conn, v6conn *net.UDPConn
   198  	var v4pc *ipv4.PacketConn
   199  	var v6pc *ipv6.PacketConn
   200  
   201  	v4conn, port, err = s.listenNet("udp4", port)
   202  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   203  		return nil, 0, err
   204  	}
   205  
   206  	// Listen on the same port as we're using for ipv4.
   207  	v6conn, port, err = s.listenNet("udp6", port)
   208  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   209  		v4conn.Close()
   210  		tries++
   211  		goto again
   212  	}
   213  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   214  		v4conn.Close()
   215  		return nil, 0, err
   216  	}
   217  	var fns []wg.ReceiveFunc
   218  	if v4conn != nil {
   219  		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
   220  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   221  			v4pc = ipv4.NewPacketConn(v4conn)
   222  			s.ipv4PC = v4pc
   223  		}
   224  		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
   225  		s.ipv4 = v4conn
   226  	}
   227  	if v6conn != nil {
   228  		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
   229  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   230  			v6pc = ipv6.NewPacketConn(v6conn)
   231  			s.ipv6PC = v6pc
   232  		}
   233  		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
   234  		s.ipv6 = v6conn
   235  	}
   236  	if len(fns) == 0 {
   237  		return nil, 0, syscall.EAFNOSUPPORT
   238  	}
   239  
   240  	return fns, uint16(port), nil
   241  }
   242  
   243  func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
   244  	for i := range *msgs {
   245  		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
   246  		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
   247  	}
   248  	s.msgsPool.Put(msgs)
   249  }
   250  
   251  func (s *StdNetBind) getMessages() *[]ipv6.Message {
   252  	return s.msgsPool.Get().(*[]ipv6.Message)
   253  }
   254  
   255  var _ ipv6.Message = ipv4.Message{} // If compilation fails here these are no longer the same underlying type.
   256  
   257  type batchReader interface {
   258  	ReadBatch([]ipv6.Message, int) (int, error)
   259  }
   260  
   261  type batchWriter interface {
   262  	WriteBatch([]ipv6.Message, int) (int, error)
   263  }
   264  
// receiveIP reads one batch of inbound datagrams from a single socket into
// bufs. For each message slot i it stores the length in sizes[i] and the
// sender in eps[i], and it returns the number of slots filled.
//
// Platform behavior:
//   - Linux/Android: reads with br.ReadBatch. With rxOffload, coalesced
//     datagrams are read into the tail of the message batch and split back
//     into individual messages by splitCoalescedMessages.
//   - Other platforms: reads a single datagram with conn.ReadMsgUDP.
//
// Slots whose size is 0 are skipped, leaving their eps entry unchanged. Every
// non-empty message has bytes 1-3 of its buffer zeroed by resetReserved,
// regardless of whether a reserved value is configured.
//
// NOTE(review): assumes len(bufs) <= wg.IdealBatchSize (the pooled batch
// length); larger inputs would index past the batch — confirm with callers.
func (s *StdNetBind) receiveIP(
	br batchReader,
	conn *net.UDPConn,
	rxOffload bool,
	bufs [][]byte,
	sizes []int,
	eps []wg.Endpoint,
) (numMsgs int, err error) {
	msgs := s.getMessages()
	// Point each pooled message at the caller's buffer and expose the full OOB
	// capacity so control data can be received.
	for i := range bufs {
		(*msgs)[i].Buffers[0] = bufs[i]
		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
	}
	defer s.putMessages(msgs)
	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
		if rxOffload {
			// Read only into the last IdealBatchSize/udpSegmentMaxDatagrams
			// slots; splitCoalescedMessages then spreads the segments into the
			// batch starting at index 0 (presumably why reads start at the tail).
			readAt := len(*msgs) - (wg.IdealBatchSize / udpSegmentMaxDatagrams)
			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
			if err != nil {
				return 0, err
			}
			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
			if err != nil {
				return 0, err
			}
		} else {
			numMsgs, err = br.ReadBatch(*msgs, 0)
			if err != nil {
				return 0, err
			}
		}
	} else {
		// No batching support: read exactly one datagram into slot 0.
		msg := &(*msgs)[0]
		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
		if err != nil {
			return 0, err
		}
		numMsgs = 1
	}
	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		if sizes[i] == 0 {
			continue
		}
		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
		ep := &wg.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
		getSrcFromControl(msg.OOB[:msg.NN], ep)
		eps[i] = ep
		// Undo any non-zero reserved header bytes before WireGuard sees them.
		s.resetReserved(msg.Buffers[0])
	}
	return numMsgs, nil
}
   318  
   319  func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) wg.ReceiveFunc {
   320  	return func(bufs [][]byte, sizes []int, eps []wg.Endpoint) (n int, err error) {
   321  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   322  	}
   323  }
   324  
   325  func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) wg.ReceiveFunc {
   326  	return func(bufs [][]byte, sizes []int, eps []wg.Endpoint) (n int, err error) {
   327  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
   328  	}
   329  }
   330  
   331  // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
   332  // rename the IdealBatchSize constant to BatchSize.
   333  func (s *StdNetBind) BatchSize() int {
   334  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   335  		return wg.IdealBatchSize
   336  	}
   337  	return 1
   338  }
   339  
   340  func (s *StdNetBind) Close() error {
   341  	s.mu.Lock()
   342  	defer s.mu.Unlock()
   343  
   344  	var err1, err2 error
   345  	if s.ipv4 != nil {
   346  		err1 = s.ipv4.Close()
   347  		s.ipv4 = nil
   348  		s.ipv4PC = nil
   349  	}
   350  	if s.ipv6 != nil {
   351  		err2 = s.ipv6.Close()
   352  		s.ipv6 = nil
   353  		s.ipv6PC = nil
   354  	}
   355  	s.blackhole4 = false
   356  	s.blackhole6 = false
   357  	s.ipv4TxOffload = false
   358  	s.ipv4RxOffload = false
   359  	s.ipv6TxOffload = false
   360  	s.ipv6RxOffload = false
   361  	if err1 != nil {
   362  		return err1
   363  	}
   364  	return err2
   365  }
   366  
   367  func (s *StdNetBind) Send(bufs [][]byte, endpoint wg.Endpoint) error {
   368  	s.mu.Lock()
   369  	blackhole := s.blackhole4
   370  	conn := s.ipv4
   371  	offload := s.ipv4TxOffload
   372  	br := batchWriter(s.ipv4PC)
   373  	is6 := false
   374  	if endpoint.DstIP().Is6() {
   375  		blackhole = s.blackhole6
   376  		conn = s.ipv6
   377  		br = s.ipv6PC
   378  		is6 = true
   379  		offload = s.ipv6TxOffload
   380  	}
   381  	s.mu.Unlock()
   382  
   383  	if blackhole {
   384  		return nil
   385  	}
   386  	if conn == nil {
   387  		return syscall.EAFNOSUPPORT
   388  	}
   389  
   390  	for i := range bufs {
   391  		s.setReserved(bufs[i])
   392  	}
   393  
   394  	msgs := s.getMessages()
   395  	defer s.putMessages(msgs)
   396  	ua := s.udpAddrPool.Get().(*net.UDPAddr)
   397  	defer s.udpAddrPool.Put(ua)
   398  	if is6 {
   399  		as16 := endpoint.DstIP().As16()
   400  		copy(ua.IP, as16[:])
   401  		ua.IP = ua.IP[:16]
   402  	} else {
   403  		as4 := endpoint.DstIP().As4()
   404  		copy(ua.IP, as4[:])
   405  		ua.IP = ua.IP[:4]
   406  	}
   407  	ua.Port = int(endpoint.(*wg.StdNetEndpoint).Port())
   408  	var (
   409  		retried bool
   410  		err     error
   411  	)
   412  retry:
   413  	if offload {
   414  		n := coalesceMessages(ua, endpoint.(*wg.StdNetEndpoint), bufs, *msgs, setGSOSize)
   415  		err = s.send(conn, br, (*msgs)[:n])
   416  		if err != nil && offload && errShouldDisableUDPGSO(err) {
   417  			offload = false
   418  			s.mu.Lock()
   419  			if is6 {
   420  				s.ipv6TxOffload = false
   421  			} else {
   422  				s.ipv4TxOffload = false
   423  			}
   424  			s.mu.Unlock()
   425  			retried = true
   426  			goto retry
   427  		}
   428  	} else {
   429  		for i := range bufs {
   430  			(*msgs)[i].Addr = ua
   431  			(*msgs)[i].Buffers[0] = bufs[i]
   432  			setSrcControl(&(*msgs)[i].OOB, endpoint.(*wg.StdNetEndpoint))
   433  		}
   434  		err = s.send(conn, br, (*msgs)[:len(bufs)])
   435  	}
   436  	if retried {
   437  		return wg.ErrUDPGSODisabled{RetryErr: fmt.Errorf("disabled UDP GSO on %s, %w", conn.LocalAddr().String(), err)}
   438  	}
   439  	return err
   440  }
   441  
   442  func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
   443  	var (
   444  		n     int
   445  		err   error
   446  		start int
   447  	)
   448  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   449  		for {
   450  			n, err = pc.WriteBatch(msgs[start:], 0)
   451  			if err != nil || n == len(msgs[start:]) {
   452  				break
   453  			}
   454  			start += n
   455  		}
   456  	} else {
   457  		for _, msg := range msgs {
   458  			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
   459  			if err != nil {
   460  				break
   461  			}
   462  		}
   463  	}
   464  	return err
   465  }