github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/conn/bind_std.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"errors"
    10  	"net"
    11  	"sync"
    12  	"syscall"
    13  
    14  	"golang.zx2c4.com/go118/netip"
    15  )
    16  
    17  // StdNetBind is meant to be a temporary solution on platforms for which
    18  // the sticky socket / source caching behavior has not yet been implemented.
    19  // It uses the Go's net package to implement networking.
    20  // See LinuxSocketBind for a proper implementation on the Linux platform.
    21  type StdNetBind struct {
    22  	mu         sync.Mutex // protects following fields
    23  	ipv4       *net.UDPConn
    24  	ipv6       *net.UDPConn
    25  	blackhole4 bool
    26  	blackhole6 bool
    27  }
    28  
    29  func NewStdNetBind() Bind { return &StdNetBind{} }
    30  
    31  type StdNetEndpoint net.UDPAddr
    32  
    33  var (
    34  	_ Bind     = (*StdNetBind)(nil)
    35  	_ Endpoint = (*StdNetEndpoint)(nil)
    36  )
    37  
    38  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
    39  	e, err := netip.ParseAddrPort(s)
    40  	return (*StdNetEndpoint)(&net.UDPAddr{
    41  		IP:   e.Addr().AsSlice(),
    42  		Port: int(e.Port()),
    43  		Zone: e.Addr().Zone(),
    44  	}), err
    45  }
    46  
    47  func (*StdNetEndpoint) ClearSrc() {}
    48  
    49  func (e *StdNetEndpoint) DstIP() netip.Addr {
    50  	a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP)
    51  	return a
    52  }
    53  
    54  func (e *StdNetEndpoint) SrcIP() netip.Addr {
    55  	return netip.Addr{} // not supported
    56  }
    57  
    58  func (e *StdNetEndpoint) DstToBytes() []byte {
    59  	addr := (*net.UDPAddr)(e)
    60  	out := addr.IP.To4()
    61  	if out == nil {
    62  		out = addr.IP
    63  	}
    64  	out = append(out, byte(addr.Port&0xff))
    65  	out = append(out, byte((addr.Port>>8)&0xff))
    66  	return out
    67  }
    68  
    69  func (e *StdNetEndpoint) DstToString() string {
    70  	return (*net.UDPAddr)(e).String()
    71  }
    72  
    73  func (e *StdNetEndpoint) SrcToString() string {
    74  	return ""
    75  }
    76  
    77  func listenNet(network string, port int) (*net.UDPConn, int, error) {
    78  	conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
    79  	if err != nil {
    80  		return nil, 0, err
    81  	}
    82  
    83  	// Retrieve port.
    84  	laddr := conn.LocalAddr()
    85  	uaddr, err := net.ResolveUDPAddr(
    86  		laddr.Network(),
    87  		laddr.String(),
    88  	)
    89  	if err != nil {
    90  		return nil, 0, err
    91  	}
    92  	return conn, uaddr.Port, nil
    93  }
    94  
    95  func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
    96  	bind.mu.Lock()
    97  	defer bind.mu.Unlock()
    98  
    99  	var err error
   100  	var tries int
   101  
   102  	if bind.ipv4 != nil || bind.ipv6 != nil {
   103  		return nil, 0, ErrBindAlreadyOpen
   104  	}
   105  
   106  	// Attempt to open ipv4 and ipv6 listeners on the same port.
   107  	// If uport is 0, we can retry on failure.
   108  again:
   109  	port := int(uport)
   110  	var ipv4, ipv6 *net.UDPConn
   111  
   112  	ipv4, port, err = listenNet("udp4", port)
   113  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   114  		return nil, 0, err
   115  	}
   116  
   117  	// Listen on the same port as we're using for ipv4.
   118  	ipv6, port, err = listenNet("udp6", port)
   119  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   120  		ipv4.Close()
   121  		tries++
   122  		goto again
   123  	}
   124  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   125  		ipv4.Close()
   126  		return nil, 0, err
   127  	}
   128  	var fns []ReceiveFunc
   129  	if ipv4 != nil {
   130  		fns = append(fns, bind.makeReceiveIPv4(ipv4))
   131  		bind.ipv4 = ipv4
   132  	}
   133  	if ipv6 != nil {
   134  		fns = append(fns, bind.makeReceiveIPv6(ipv6))
   135  		bind.ipv6 = ipv6
   136  	}
   137  	if len(fns) == 0 {
   138  		return nil, 0, syscall.EAFNOSUPPORT
   139  	}
   140  	return fns, uint16(port), nil
   141  }
   142  
   143  func (bind *StdNetBind) Close() error {
   144  	bind.mu.Lock()
   145  	defer bind.mu.Unlock()
   146  
   147  	var err1, err2 error
   148  	if bind.ipv4 != nil {
   149  		err1 = bind.ipv4.Close()
   150  		bind.ipv4 = nil
   151  	}
   152  	if bind.ipv6 != nil {
   153  		err2 = bind.ipv6.Close()
   154  		bind.ipv6 = nil
   155  	}
   156  	bind.blackhole4 = false
   157  	bind.blackhole6 = false
   158  	if err1 != nil {
   159  		return err1
   160  	}
   161  	return err2
   162  }
   163  
   164  func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
   165  	return func(buff []byte) (int, Endpoint, error) {
   166  		n, endpoint, err := conn.ReadFromUDP(buff)
   167  		if endpoint != nil {
   168  			endpoint.IP = endpoint.IP.To4()
   169  		}
   170  		return n, (*StdNetEndpoint)(endpoint), err
   171  	}
   172  }
   173  
   174  func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
   175  	return func(buff []byte) (int, Endpoint, error) {
   176  		n, endpoint, err := conn.ReadFromUDP(buff)
   177  		return n, (*StdNetEndpoint)(endpoint), err
   178  	}
   179  }
   180  
   181  func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
   182  	var err error
   183  	nend, ok := endpoint.(*StdNetEndpoint)
   184  	if !ok {
   185  		return ErrWrongEndpointType
   186  	}
   187  
   188  	bind.mu.Lock()
   189  	blackhole := bind.blackhole4
   190  	conn := bind.ipv4
   191  	if nend.IP.To4() == nil {
   192  		blackhole = bind.blackhole6
   193  		conn = bind.ipv6
   194  	}
   195  	bind.mu.Unlock()
   196  
   197  	if blackhole {
   198  		return nil
   199  	}
   200  	if conn == nil {
   201  		return syscall.EAFNOSUPPORT
   202  	}
   203  	_, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
   204  	return err
   205  }