github.com/cawidtu/notwireguard-go/conn@v0.0.0-20230523131112-68e8e5ce9cdf/bind_std.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"errors"
    10  	"net"
    11  	"net/netip"
    12  	"sync"
    13  	"syscall"
    14  )
    15  
    16  // StdNetBind is meant to be a temporary solution on platforms for which
    17  // the sticky socket / source caching behavior has not yet been implemented.
    18  // It uses the Go's net package to implement networking.
    19  // See LinuxSocketBind for a proper implementation on the Linux platform.
    20  type StdNetBind struct {
    21  	mu         sync.Mutex // protects following fields
    22  	ipv4       *net.UDPConn
    23  	ipv6       *net.UDPConn
    24  	blackhole4 bool
    25  	blackhole6 bool
    26  }
    27  
    28  func NewStdNetBind() Bind { return &StdNetBind{} }
    29  
    30  type StdNetEndpoint netip.AddrPort
    31  
    32  var (
    33  	_ Bind     = (*StdNetBind)(nil)
    34  	_ Endpoint = StdNetEndpoint{}
    35  )
    36  
    37  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
    38  	e, err := netip.ParseAddrPort(s)
    39  	return asEndpoint(e), err
    40  }
    41  
    42  func (StdNetEndpoint) ClearSrc() {}
    43  
    44  func (e StdNetEndpoint) DstIP() netip.Addr {
    45  	return (netip.AddrPort)(e).Addr()
    46  }
    47  
    48  func (e StdNetEndpoint) SrcIP() netip.Addr {
    49  	return netip.Addr{} // not supported
    50  }
    51  
    52  func (e StdNetEndpoint) DstToBytes() []byte {
    53  	b, _ := (netip.AddrPort)(e).MarshalBinary()
    54  	return b
    55  }
    56  
    57  func (e StdNetEndpoint) DstToString() string {
    58  	return (netip.AddrPort)(e).String()
    59  }
    60  
    61  func (e StdNetEndpoint) SrcToString() string {
    62  	return ""
    63  }
    64  
    65  func listenNet(network string, port int) (*net.UDPConn, int, error) {
    66  	conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
    67  	if err != nil {
    68  		return nil, 0, err
    69  	}
    70  
    71  	// Retrieve port.
    72  	laddr := conn.LocalAddr()
    73  	uaddr, err := net.ResolveUDPAddr(
    74  		laddr.Network(),
    75  		laddr.String(),
    76  	)
    77  	if err != nil {
    78  		return nil, 0, err
    79  	}
    80  	return conn, uaddr.Port, nil
    81  }
    82  
    83  func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
    84  	bind.mu.Lock()
    85  	defer bind.mu.Unlock()
    86  
    87  	var err error
    88  	var tries int
    89  
    90  	if bind.ipv4 != nil || bind.ipv6 != nil {
    91  		return nil, 0, ErrBindAlreadyOpen
    92  	}
    93  
    94  	// Attempt to open ipv4 and ipv6 listeners on the same port.
    95  	// If uport is 0, we can retry on failure.
    96  again:
    97  	port := int(uport)
    98  	var ipv4, ipv6 *net.UDPConn
    99  
   100  	ipv4, port, err = listenNet("udp4", port)
   101  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   102  		return nil, 0, err
   103  	}
   104  
   105  	// Listen on the same port as we're using for ipv4.
   106  	ipv6, port, err = listenNet("udp6", port)
   107  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   108  		ipv4.Close()
   109  		tries++
   110  		goto again
   111  	}
   112  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   113  		ipv4.Close()
   114  		return nil, 0, err
   115  	}
   116  	var fns []ReceiveFunc
   117  	if ipv4 != nil {
   118  		fns = append(fns, bind.makeReceiveIPv4(ipv4))
   119  		bind.ipv4 = ipv4
   120  	}
   121  	if ipv6 != nil {
   122  		fns = append(fns, bind.makeReceiveIPv6(ipv6))
   123  		bind.ipv6 = ipv6
   124  	}
   125  	if len(fns) == 0 {
   126  		return nil, 0, syscall.EAFNOSUPPORT
   127  	}
   128  	return fns, uint16(port), nil
   129  }
   130  
   131  func (bind *StdNetBind) Close() error {
   132  	bind.mu.Lock()
   133  	defer bind.mu.Unlock()
   134  
   135  	var err1, err2 error
   136  	if bind.ipv4 != nil {
   137  		err1 = bind.ipv4.Close()
   138  		bind.ipv4 = nil
   139  	}
   140  	if bind.ipv6 != nil {
   141  		err2 = bind.ipv6.Close()
   142  		bind.ipv6 = nil
   143  	}
   144  	bind.blackhole4 = false
   145  	bind.blackhole6 = false
   146  	if err1 != nil {
   147  		return err1
   148  	}
   149  	return err2
   150  }
   151  
   152  func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
   153  	return func(buff []byte) (int, Endpoint, error) {
   154  		n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
   155  		return n, asEndpoint(endpoint), err
   156  	}
   157  }
   158  
   159  func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
   160  	return func(buff []byte) (int, Endpoint, error) {
   161  		n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
   162  		return n, asEndpoint(endpoint), err
   163  	}
   164  }
   165  
   166  func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
   167  	var err error
   168  	nend, ok := endpoint.(StdNetEndpoint)
   169  	if !ok {
   170  		return ErrWrongEndpointType
   171  	}
   172  	addrPort := netip.AddrPort(nend)
   173  
   174  	bind.mu.Lock()
   175  	blackhole := bind.blackhole4
   176  	conn := bind.ipv4
   177  	if addrPort.Addr().Is6() {
   178  		blackhole = bind.blackhole6
   179  		conn = bind.ipv6
   180  	}
   181  	bind.mu.Unlock()
   182  
   183  	if blackhole {
   184  		return nil
   185  	}
   186  	if conn == nil {
   187  		return syscall.EAFNOSUPPORT
   188  	}
   189  	_, err = conn.WriteToUDPAddrPort(buff, addrPort)
   190  	return err
   191  }
   192  
   193  // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
   194  // This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
   195  // but Endpoints are immutable, so we can re-use them.
   196  var endpointPool = sync.Pool{
   197  	New: func() any {
   198  		return make(map[netip.AddrPort]Endpoint)
   199  	},
   200  }
   201  
   202  // asEndpoint returns an Endpoint containing ap.
   203  func asEndpoint(ap netip.AddrPort) Endpoint {
   204  	m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
   205  	defer endpointPool.Put(m)
   206  	e, ok := m[ap]
   207  	if !ok {
   208  		e = Endpoint(StdNetEndpoint(ap))
   209  		m[ap] = e
   210  	}
   211  	return e
   212  }