github.com/forest33/wtun@v0.3.1/conn/bind_std.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"context"
    10  	"errors"
    11  	"net"
    12  	"net/netip"
    13  	"runtime"
    14  	"strconv"
    15  	"sync"
    16  	"syscall"
    17  
    18  	"golang.org/x/net/ipv4"
    19  	"golang.org/x/net/ipv6"
    20  )
    21  
    22  var (
    23  	_ Bind = (*StdNetBind)(nil)
    24  )
    25  
    26  // StdNetBind implements Bind for all platforms. While Windows has its own Bind
    27  // (see bind_windows.go), it may fall back to StdNetBind.
    28  // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
    29  // methods for sending and receiving multiple datagrams per-syscall. See the
    30  // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
    31  type StdNetBind struct {
    32  	mu     sync.Mutex // protects all fields except as specified
    33  	ipv4   *net.UDPConn
    34  	ipv6   *net.UDPConn
    35  	ipv4PC *ipv4.PacketConn // will be nil on non-Linux
    36  	ipv6PC *ipv6.PacketConn // will be nil on non-Linux
    37  
    38  	// these three fields are not guarded by mu
    39  	udpAddrPool  sync.Pool
    40  	ipv4MsgsPool sync.Pool
    41  	ipv6MsgsPool sync.Pool
    42  
    43  	blackhole4 bool
    44  	blackhole6 bool
    45  }
    46  
    47  func NewStdNetBind() Bind {
    48  	return &StdNetBind{
    49  		udpAddrPool: sync.Pool{
    50  			New: func() any {
    51  				return &net.UDPAddr{
    52  					IP: make([]byte, 16),
    53  				}
    54  			},
    55  		},
    56  
    57  		ipv4MsgsPool: sync.Pool{
    58  			New: func() any {
    59  				msgs := make([]ipv4.Message, IdealBatchSize)
    60  				for i := range msgs {
    61  					msgs[i].Buffers = make(net.Buffers, 1)
    62  					msgs[i].OOB = make([]byte, srcControlSize)
    63  				}
    64  				return &msgs
    65  			},
    66  		},
    67  
    68  		ipv6MsgsPool: sync.Pool{
    69  			New: func() any {
    70  				msgs := make([]ipv6.Message, IdealBatchSize)
    71  				for i := range msgs {
    72  					msgs[i].Buffers = make(net.Buffers, 1)
    73  					msgs[i].OOB = make([]byte, srcControlSize)
    74  				}
    75  				return &msgs
    76  			},
    77  		},
    78  	}
    79  }
    80  
    81  type StdNetEndpoint struct {
    82  	// AddrPort is the endpoint destination.
    83  	netip.AddrPort
    84  	// src is the current sticky source address and interface index, if
    85  	// supported. Typically this is a PKTINFO structure from/for control
    86  	// messages, see unix.PKTINFO for an example.
    87  	src []byte
    88  }
    89  
    90  var (
    91  	_ Bind     = (*StdNetBind)(nil)
    92  	_ Endpoint = &StdNetEndpoint{}
    93  )
    94  
    95  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
    96  	e, err := netip.ParseAddrPort(s)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	return &StdNetEndpoint{
   101  		AddrPort: e,
   102  	}, nil
   103  }
   104  
   105  func (e *StdNetEndpoint) ClearSrc() {
   106  	if e.src != nil {
   107  		// Truncate src, no need to reallocate.
   108  		e.src = e.src[:0]
   109  	}
   110  }
   111  
   112  func (e *StdNetEndpoint) DstIP() netip.Addr {
   113  	return e.AddrPort.Addr()
   114  }
   115  
// See the sticky_* files (sticky_default, sticky_linux, etc.) for the
// platform-specific implementations of SrcIP and SrcIfidx.
   117  
   118  func (e *StdNetEndpoint) DstToBytes() []byte {
   119  	b, _ := e.AddrPort.MarshalBinary()
   120  	return b
   121  }
   122  
   123  func (e *StdNetEndpoint) DstToString() string {
   124  	return e.AddrPort.String()
   125  }
   126  
   127  func listenNet(network string, port int) (*net.UDPConn, int, error) {
   128  	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
   129  	if err != nil {
   130  		return nil, 0, err
   131  	}
   132  
   133  	// Retrieve port.
   134  	laddr := conn.LocalAddr()
   135  	uaddr, err := net.ResolveUDPAddr(
   136  		laddr.Network(),
   137  		laddr.String(),
   138  	)
   139  	if err != nil {
   140  		return nil, 0, err
   141  	}
   142  	return conn.(*net.UDPConn), uaddr.Port, nil
   143  }
   144  
   145  func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
   146  	s.mu.Lock()
   147  	defer s.mu.Unlock()
   148  
   149  	var err error
   150  	var tries int
   151  
   152  	if s.ipv4 != nil || s.ipv6 != nil {
   153  		return nil, 0, ErrBindAlreadyOpen
   154  	}
   155  
   156  	// Attempt to open ipv4 and ipv6 listeners on the same port.
   157  	// If uport is 0, we can retry on failure.
   158  again:
   159  	port := int(uport)
   160  	var v4conn, v6conn *net.UDPConn
   161  	var v4pc *ipv4.PacketConn
   162  	var v6pc *ipv6.PacketConn
   163  
   164  	v4conn, port, err = listenNet("udp4", port)
   165  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   166  		return nil, 0, err
   167  	}
   168  
   169  	// Listen on the same port as we're using for ipv4.
   170  	v6conn, port, err = listenNet("udp6", port)
   171  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   172  		v4conn.Close()
   173  		tries++
   174  		goto again
   175  	}
   176  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   177  		v4conn.Close()
   178  		return nil, 0, err
   179  	}
   180  	var fns []ReceiveFunc
   181  	if v4conn != nil {
   182  		if runtime.GOOS == "linux" {
   183  			v4pc = ipv4.NewPacketConn(v4conn)
   184  			s.ipv4PC = v4pc
   185  		}
   186  		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
   187  		s.ipv4 = v4conn
   188  	}
   189  	if v6conn != nil {
   190  		if runtime.GOOS == "linux" {
   191  			v6pc = ipv6.NewPacketConn(v6conn)
   192  			s.ipv6PC = v6pc
   193  		}
   194  		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
   195  		s.ipv6 = v6conn
   196  	}
   197  	if len(fns) == 0 {
   198  		return nil, 0, syscall.EAFNOSUPPORT
   199  	}
   200  
   201  	return fns, uint16(port), nil
   202  }
   203  
   204  func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
   205  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   206  		msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
   207  		defer s.ipv4MsgsPool.Put(msgs)
   208  		for i := range bufs {
   209  			(*msgs)[i].Buffers[0] = bufs[i]
   210  		}
   211  		var numMsgs int
   212  		if runtime.GOOS == "linux" {
   213  			numMsgs, err = pc.ReadBatch(*msgs, 0)
   214  			if err != nil {
   215  				return 0, err
   216  			}
   217  		} else {
   218  			msg := &(*msgs)[0]
   219  			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
   220  			if err != nil {
   221  				return 0, err
   222  			}
   223  			numMsgs = 1
   224  		}
   225  		for i := 0; i < numMsgs; i++ {
   226  			msg := &(*msgs)[i]
   227  			sizes[i] = msg.N
   228  			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
   229  			ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
   230  			getSrcFromControl(msg.OOB[:msg.NN], ep)
   231  			eps[i] = ep
   232  		}
   233  		return numMsgs, nil
   234  	}
   235  }
   236  
   237  func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
   238  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
   239  		msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
   240  		defer s.ipv6MsgsPool.Put(msgs)
   241  		for i := range bufs {
   242  			(*msgs)[i].Buffers[0] = bufs[i]
   243  		}
   244  		var numMsgs int
   245  		if runtime.GOOS == "linux" {
   246  			numMsgs, err = pc.ReadBatch(*msgs, 0)
   247  			if err != nil {
   248  				return 0, err
   249  			}
   250  		} else {
   251  			msg := &(*msgs)[0]
   252  			msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
   253  			if err != nil {
   254  				return 0, err
   255  			}
   256  			numMsgs = 1
   257  		}
   258  		for i := 0; i < numMsgs; i++ {
   259  			msg := &(*msgs)[i]
   260  			sizes[i] = msg.N
   261  			addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
   262  			ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
   263  			getSrcFromControl(msg.OOB[:msg.NN], ep)
   264  			eps[i] = ep
   265  		}
   266  		return numMsgs, nil
   267  	}
   268  }
   269  
   270  // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
   271  // rename the IdealBatchSize constant to BatchSize.
   272  func (s *StdNetBind) BatchSize() int {
   273  	if runtime.GOOS == "linux" {
   274  		return IdealBatchSize
   275  	}
   276  	return 1
   277  }
   278  
   279  func (s *StdNetBind) Close() error {
   280  	s.mu.Lock()
   281  	defer s.mu.Unlock()
   282  
   283  	var err1, err2 error
   284  	if s.ipv4 != nil {
   285  		err1 = s.ipv4.Close()
   286  		s.ipv4 = nil
   287  		s.ipv4PC = nil
   288  	}
   289  	if s.ipv6 != nil {
   290  		err2 = s.ipv6.Close()
   291  		s.ipv6 = nil
   292  		s.ipv6PC = nil
   293  	}
   294  	s.blackhole4 = false
   295  	s.blackhole6 = false
   296  	if err1 != nil {
   297  		return err1
   298  	}
   299  	return err2
   300  }
   301  
   302  func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
   303  	s.mu.Lock()
   304  	blackhole := s.blackhole4
   305  	conn := s.ipv4
   306  	var (
   307  		pc4 *ipv4.PacketConn
   308  		pc6 *ipv6.PacketConn
   309  	)
   310  	is6 := false
   311  	if endpoint.DstIP().Is6() {
   312  		blackhole = s.blackhole6
   313  		conn = s.ipv6
   314  		pc6 = s.ipv6PC
   315  		is6 = true
   316  	} else {
   317  		pc4 = s.ipv4PC
   318  	}
   319  	s.mu.Unlock()
   320  
   321  	if blackhole {
   322  		return nil
   323  	}
   324  	if conn == nil {
   325  		return syscall.EAFNOSUPPORT
   326  	}
   327  	if is6 {
   328  		return s.send6(conn, pc6, endpoint, bufs)
   329  	} else {
   330  		return s.send4(conn, pc4, endpoint, bufs)
   331  	}
   332  }
   333  
   334  func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error {
   335  	ua := s.udpAddrPool.Get().(*net.UDPAddr)
   336  	as4 := ep.DstIP().As4()
   337  	copy(ua.IP, as4[:])
   338  	ua.IP = ua.IP[:4]
   339  	ua.Port = int(ep.(*StdNetEndpoint).Port())
   340  	msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
   341  	for i, buf := range bufs {
   342  		(*msgs)[i].Buffers[0] = buf
   343  		(*msgs)[i].Addr = ua
   344  		setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
   345  	}
   346  	var (
   347  		n     int
   348  		err   error
   349  		start int
   350  	)
   351  	if runtime.GOOS == "linux" {
   352  		for {
   353  			n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
   354  			if err != nil || n == len((*msgs)[start:len(bufs)]) {
   355  				break
   356  			}
   357  			start += n
   358  		}
   359  	} else {
   360  		for i, buf := range bufs {
   361  			_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
   362  			if err != nil {
   363  				break
   364  			}
   365  		}
   366  	}
   367  	s.udpAddrPool.Put(ua)
   368  	s.ipv4MsgsPool.Put(msgs)
   369  	return err
   370  }
   371  
   372  func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error {
   373  	ua := s.udpAddrPool.Get().(*net.UDPAddr)
   374  	as16 := ep.DstIP().As16()
   375  	copy(ua.IP, as16[:])
   376  	ua.IP = ua.IP[:16]
   377  	ua.Port = int(ep.(*StdNetEndpoint).Port())
   378  	msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
   379  	for i, buf := range bufs {
   380  		(*msgs)[i].Buffers[0] = buf
   381  		(*msgs)[i].Addr = ua
   382  		setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
   383  	}
   384  	var (
   385  		n     int
   386  		err   error
   387  		start int
   388  	)
   389  	if runtime.GOOS == "linux" {
   390  		for {
   391  			n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0)
   392  			if err != nil || n == len((*msgs)[start:len(bufs)]) {
   393  				break
   394  			}
   395  			start += n
   396  		}
   397  	} else {
   398  		for i, buf := range bufs {
   399  			_, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua)
   400  			if err != nil {
   401  				break
   402  			}
   403  		}
   404  	}
   405  	s.udpAddrPool.Put(ua)
   406  	s.ipv6MsgsPool.Put(msgs)
   407  	return err
   408  }