github.com/cawidtu/notwireguard-go/conn@v0.0.0-20230523131112-68e8e5ce9cdf/bind_std.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "errors" 10 "net" 11 "net/netip" 12 "sync" 13 "syscall" 14 ) 15 16 // StdNetBind is meant to be a temporary solution on platforms for which 17 // the sticky socket / source caching behavior has not yet been implemented. 18 // It uses the Go's net package to implement networking. 19 // See LinuxSocketBind for a proper implementation on the Linux platform. 20 type StdNetBind struct { 21 mu sync.Mutex // protects following fields 22 ipv4 *net.UDPConn 23 ipv6 *net.UDPConn 24 blackhole4 bool 25 blackhole6 bool 26 } 27 28 func NewStdNetBind() Bind { return &StdNetBind{} } 29 30 type StdNetEndpoint netip.AddrPort 31 32 var ( 33 _ Bind = (*StdNetBind)(nil) 34 _ Endpoint = StdNetEndpoint{} 35 ) 36 37 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { 38 e, err := netip.ParseAddrPort(s) 39 return asEndpoint(e), err 40 } 41 42 func (StdNetEndpoint) ClearSrc() {} 43 44 func (e StdNetEndpoint) DstIP() netip.Addr { 45 return (netip.AddrPort)(e).Addr() 46 } 47 48 func (e StdNetEndpoint) SrcIP() netip.Addr { 49 return netip.Addr{} // not supported 50 } 51 52 func (e StdNetEndpoint) DstToBytes() []byte { 53 b, _ := (netip.AddrPort)(e).MarshalBinary() 54 return b 55 } 56 57 func (e StdNetEndpoint) DstToString() string { 58 return (netip.AddrPort)(e).String() 59 } 60 61 func (e StdNetEndpoint) SrcToString() string { 62 return "" 63 } 64 65 func listenNet(network string, port int) (*net.UDPConn, int, error) { 66 conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) 67 if err != nil { 68 return nil, 0, err 69 } 70 71 // Retrieve port. 72 laddr := conn.LocalAddr() 73 uaddr, err := net.ResolveUDPAddr( 74 laddr.Network(), 75 laddr.String(), 76 ) 77 if err != nil { 78 return nil, 0, err 79 } 80 return conn, uaddr.Port, nil 81 } 82 83 func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { 84 bind.mu.Lock() 85 defer bind.mu.Unlock() 86 87 var err error 88 var tries int 89 90 if bind.ipv4 != nil || bind.ipv6 != nil { 91 return nil, 0, ErrBindAlreadyOpen 92 } 93 94 // Attempt to open ipv4 and ipv6 listeners on the same port. 95 // If uport is 0, we can retry on failure. 96 again: 97 port := int(uport) 98 var ipv4, ipv6 *net.UDPConn 99 100 ipv4, port, err = listenNet("udp4", port) 101 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 102 return nil, 0, err 103 } 104 105 // Listen on the same port as we're using for ipv4. 106 ipv6, port, err = listenNet("udp6", port) 107 if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 108 ipv4.Close() 109 tries++ 110 goto again 111 } 112 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 113 ipv4.Close() 114 return nil, 0, err 115 } 116 var fns []ReceiveFunc 117 if ipv4 != nil { 118 fns = append(fns, bind.makeReceiveIPv4(ipv4)) 119 bind.ipv4 = ipv4 120 } 121 if ipv6 != nil { 122 fns = append(fns, bind.makeReceiveIPv6(ipv6)) 123 bind.ipv6 = ipv6 124 } 125 if len(fns) == 0 { 126 return nil, 0, syscall.EAFNOSUPPORT 127 } 128 return fns, uint16(port), nil 129 } 130 131 func (bind *StdNetBind) Close() error { 132 bind.mu.Lock() 133 defer bind.mu.Unlock() 134 135 var err1, err2 error 136 if bind.ipv4 != nil { 137 err1 = bind.ipv4.Close() 138 bind.ipv4 = nil 139 } 140 if bind.ipv6 != nil { 141 err2 = bind.ipv6.Close() 142 bind.ipv6 = nil 143 } 144 bind.blackhole4 = false 145 bind.blackhole6 = false 146 if err1 != nil { 147 return err1 148 } 149 return err2 150 } 151 152 func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { 153 return func(buff []byte) (int, Endpoint, error) { 154 n, endpoint, err := conn.ReadFromUDPAddrPort(buff) 155 return n, asEndpoint(endpoint), err 156 } 157 } 158 159 func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { 160 return func(buff []byte) (int, Endpoint, error) { 161 n, endpoint, err := conn.ReadFromUDPAddrPort(buff) 162 return n, asEndpoint(endpoint), err 163 } 164 } 165 166 func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { 167 var err error 168 nend, ok := endpoint.(StdNetEndpoint) 169 if !ok { 170 return ErrWrongEndpointType 171 } 172 addrPort := netip.AddrPort(nend) 173 174 bind.mu.Lock() 175 blackhole := bind.blackhole4 176 conn := bind.ipv4 177 if addrPort.Addr().Is6() { 178 blackhole = bind.blackhole6 179 conn = bind.ipv6 180 } 181 bind.mu.Unlock() 182 183 if blackhole { 184 return nil 185 } 186 if conn == nil { 187 return syscall.EAFNOSUPPORT 188 } 189 _, err = conn.WriteToUDPAddrPort(buff, addrPort) 190 return err 191 } 192 193 // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. 194 // This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates, 195 // but Endpoints are immutable, so we can re-use them. 196 var endpointPool = sync.Pool{ 197 New: func() any { 198 return make(map[netip.AddrPort]Endpoint) 199 }, 200 } 201 202 // asEndpoint returns an Endpoint containing ap. 203 func asEndpoint(ap netip.AddrPort) Endpoint { 204 m := endpointPool.Get().(map[netip.AddrPort]Endpoint) 205 defer endpointPool.Put(m) 206 e, ok := m[ap] 207 if !ok { 208 e = Endpoint(StdNetEndpoint(ap)) 209 m[ap] = e 210 } 211 return e 212 }