github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/conn/bind_std.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conn
     7  
     8  import (
     9  	"errors"
    10  	"net"
    11  	"sync"
    12  	"syscall"
    13  )
    14  
    15  // StdNetBind is meant to be a temporary solution on platforms for which
    16  // the sticky socket / source caching behavior has not yet been implemented.
    17  // It uses the Go's net package to implement networking.
    18  // See LinuxSocketBind for a proper implementation on the Linux platform.
    19  type StdNetBind struct {
    20  	mu         sync.Mutex // protects following fields
    21  	ipv4       *net.UDPConn
    22  	ipv6       *net.UDPConn
    23  	blackhole4 bool
    24  	blackhole6 bool
    25  }
    26  
    27  func NewStdNetBind() Bind { return &StdNetBind{} }
    28  
// StdNetEndpoint is the Endpoint type produced and accepted by StdNetBind.
// It is a net.UDPAddr holding the destination address; source addresses are
// not tracked (SrcIP returns nil and SrcToString returns "").
type StdNetEndpoint net.UDPAddr
    30  
    31  var _ Bind = (*StdNetBind)(nil)
    32  var _ Endpoint = (*StdNetEndpoint)(nil)
    33  
    34  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
    35  	addr, err := parseEndpoint(s)
    36  	return (*StdNetEndpoint)(addr), err
    37  }
    38  
    39  func (*StdNetEndpoint) ClearSrc() {}
    40  
    41  func (e *StdNetEndpoint) DstIP() net.IP {
    42  	return (*net.UDPAddr)(e).IP
    43  }
    44  
    45  func (e *StdNetEndpoint) SrcIP() net.IP {
    46  	return nil // not supported
    47  }
    48  
    49  func (e *StdNetEndpoint) DstToBytes() []byte {
    50  	addr := (*net.UDPAddr)(e)
    51  	out := addr.IP.To4()
    52  	if out == nil {
    53  		out = addr.IP
    54  	}
    55  	out = append(out, byte(addr.Port&0xff))
    56  	out = append(out, byte((addr.Port>>8)&0xff))
    57  	return out
    58  }
    59  
    60  func (e *StdNetEndpoint) DstToString() string {
    61  	return (*net.UDPAddr)(e).String()
    62  }
    63  
    64  func (e *StdNetEndpoint) SrcToString() string {
    65  	return ""
    66  }
    67  
    68  func listenNet(network string, port int) (*net.UDPConn, int, error) {
    69  	conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
    70  	if err != nil {
    71  		return nil, 0, err
    72  	}
    73  
    74  	// Retrieve port.
    75  	laddr := conn.LocalAddr()
    76  	uaddr, err := net.ResolveUDPAddr(
    77  		laddr.Network(),
    78  		laddr.String(),
    79  	)
    80  	if err != nil {
    81  		return nil, 0, err
    82  	}
    83  	return conn, uaddr.Port, nil
    84  }
    85  
    86  func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
    87  	bind.mu.Lock()
    88  	defer bind.mu.Unlock()
    89  
    90  	var err error
    91  	var tries int
    92  
    93  	if bind.ipv4 != nil || bind.ipv6 != nil {
    94  		return nil, 0, ErrBindAlreadyOpen
    95  	}
    96  
    97  	// Attempt to open ipv4 and ipv6 listeners on the same port.
    98  	// If uport is 0, we can retry on failure.
    99  again:
   100  	port := int(uport)
   101  	var ipv4, ipv6 *net.UDPConn
   102  
   103  	ipv4, port, err = listenNet("udp4", port)
   104  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   105  		return nil, 0, err
   106  	}
   107  
   108  	// Listen on the same port as we're using for ipv4.
   109  	ipv6, port, err = listenNet("udp6", port)
   110  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
   111  		ipv4.Close()
   112  		tries++
   113  		goto again
   114  	}
   115  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
   116  		ipv4.Close()
   117  		return nil, 0, err
   118  	}
   119  	var fns []ReceiveFunc
   120  	if ipv4 != nil {
   121  		fns = append(fns, bind.makeReceiveIPv4(ipv4))
   122  		bind.ipv4 = ipv4
   123  	}
   124  	if ipv6 != nil {
   125  		fns = append(fns, bind.makeReceiveIPv6(ipv6))
   126  		bind.ipv6 = ipv6
   127  	}
   128  	if len(fns) == 0 {
   129  		return nil, 0, syscall.EAFNOSUPPORT
   130  	}
   131  	return fns, uint16(port), nil
   132  }
   133  
   134  func (bind *StdNetBind) Close() error {
   135  	bind.mu.Lock()
   136  	defer bind.mu.Unlock()
   137  
   138  	var err1, err2 error
   139  	if bind.ipv4 != nil {
   140  		err1 = bind.ipv4.Close()
   141  		bind.ipv4 = nil
   142  	}
   143  	if bind.ipv6 != nil {
   144  		err2 = bind.ipv6.Close()
   145  		bind.ipv6 = nil
   146  	}
   147  	bind.blackhole4 = false
   148  	bind.blackhole6 = false
   149  	if err1 != nil {
   150  		return err1
   151  	}
   152  	return err2
   153  }
   154  
   155  func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
   156  	return func(buff []byte) (int, Endpoint, error) {
   157  		n, endpoint, err := conn.ReadFromUDP(buff)
   158  		if endpoint != nil {
   159  			endpoint.IP = endpoint.IP.To4()
   160  		}
   161  		return n, (*StdNetEndpoint)(endpoint), err
   162  	}
   163  }
   164  
   165  func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
   166  	return func(buff []byte) (int, Endpoint, error) {
   167  		n, endpoint, err := conn.ReadFromUDP(buff)
   168  		return n, (*StdNetEndpoint)(endpoint), err
   169  	}
   170  }
   171  
   172  func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
   173  	var err error
   174  	nend, ok := endpoint.(*StdNetEndpoint)
   175  	if !ok {
   176  		return ErrWrongEndpointType
   177  	}
   178  
   179  	bind.mu.Lock()
   180  	blackhole := bind.blackhole4
   181  	conn := bind.ipv4
   182  	if nend.IP.To4() == nil {
   183  		blackhole = bind.blackhole6
   184  		conn = bind.ipv6
   185  	}
   186  	bind.mu.Unlock()
   187  
   188  	if blackhole {
   189  		return nil
   190  	}
   191  	if conn == nil {
   192  		return syscall.EAFNOSUPPORT
   193  	}
   194  	_, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
   195  	return err
   196  }