github.com/liloew/wireguard-go@v0.0.0-20220224014633-9cd745e6f114/conn/bind_std.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "errors" 10 "net" 11 "sync" 12 "syscall" 13 14 "golang.zx2c4.com/go118/netip" 15 ) 16 17 // StdNetBind is meant to be a temporary solution on platforms for which 18 // the sticky socket / source caching behavior has not yet been implemented. 19 // It uses the Go's net package to implement networking. 20 // See LinuxSocketBind for a proper implementation on the Linux platform. 21 type StdNetBind struct { 22 mu sync.Mutex // protects following fields 23 ipv4 *net.UDPConn 24 ipv6 *net.UDPConn 25 blackhole4 bool 26 blackhole6 bool 27 } 28 29 func NewStdNetBind() Bind { return &StdNetBind{} } 30 31 type StdNetEndpoint net.UDPAddr 32 33 var ( 34 _ Bind = (*StdNetBind)(nil) 35 _ Endpoint = (*StdNetEndpoint)(nil) 36 ) 37 38 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { 39 e, err := netip.ParseAddrPort(s) 40 return (*StdNetEndpoint)(&net.UDPAddr{ 41 IP: e.Addr().AsSlice(), 42 Port: int(e.Port()), 43 Zone: e.Addr().Zone(), 44 }), err 45 } 46 47 func (*StdNetEndpoint) ClearSrc() {} 48 49 func (e *StdNetEndpoint) DstIP() netip.Addr { 50 a, _ := netip.AddrFromSlice((*net.UDPAddr)(e).IP) 51 return a 52 } 53 54 func (e *StdNetEndpoint) SrcIP() netip.Addr { 55 return netip.Addr{} // not supported 56 } 57 58 func (e *StdNetEndpoint) DstToBytes() []byte { 59 addr := (*net.UDPAddr)(e) 60 out := addr.IP.To4() 61 if out == nil { 62 out = addr.IP 63 } 64 out = append(out, byte(addr.Port&0xff)) 65 out = append(out, byte((addr.Port>>8)&0xff)) 66 return out 67 } 68 69 func (e *StdNetEndpoint) DstToString() string { 70 return (*net.UDPAddr)(e).String() 71 } 72 73 func (e *StdNetEndpoint) SrcToString() string { 74 return "" 75 } 76 77 func listenNet(network string, port int) (*net.UDPConn, int, error) { 78 conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) 79 if err != nil { 80 return nil, 0, err 81 } 82 83 // Retrieve port. 84 laddr := conn.LocalAddr() 85 uaddr, err := net.ResolveUDPAddr( 86 laddr.Network(), 87 laddr.String(), 88 ) 89 if err != nil { 90 return nil, 0, err 91 } 92 return conn, uaddr.Port, nil 93 } 94 95 func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { 96 bind.mu.Lock() 97 defer bind.mu.Unlock() 98 99 var err error 100 var tries int 101 102 if bind.ipv4 != nil || bind.ipv6 != nil { 103 return nil, 0, ErrBindAlreadyOpen 104 } 105 106 // Attempt to open ipv4 and ipv6 listeners on the same port. 107 // If uport is 0, we can retry on failure. 108 again: 109 port := int(uport) 110 var ipv4, ipv6 *net.UDPConn 111 112 ipv4, port, err = listenNet("udp4", port) 113 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 114 return nil, 0, err 115 } 116 117 // Listen on the same port as we're using for ipv4. 118 ipv6, port, err = listenNet("udp6", port) 119 if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 120 ipv4.Close() 121 tries++ 122 goto again 123 } 124 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 125 ipv4.Close() 126 return nil, 0, err 127 } 128 var fns []ReceiveFunc 129 if ipv4 != nil { 130 fns = append(fns, bind.makeReceiveIPv4(ipv4)) 131 bind.ipv4 = ipv4 132 } 133 if ipv6 != nil { 134 fns = append(fns, bind.makeReceiveIPv6(ipv6)) 135 bind.ipv6 = ipv6 136 } 137 if len(fns) == 0 { 138 return nil, 0, syscall.EAFNOSUPPORT 139 } 140 return fns, uint16(port), nil 141 } 142 143 func (bind *StdNetBind) Close() error { 144 bind.mu.Lock() 145 defer bind.mu.Unlock() 146 147 var err1, err2 error 148 if bind.ipv4 != nil { 149 err1 = bind.ipv4.Close() 150 bind.ipv4 = nil 151 } 152 if bind.ipv6 != nil { 153 err2 = bind.ipv6.Close() 154 bind.ipv6 = nil 155 } 156 bind.blackhole4 = false 157 bind.blackhole6 = false 158 if err1 != nil { 159 return err1 160 } 161 return err2 162 } 163 164 func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { 165 return func(buff []byte) (int, Endpoint, error) { 166 n, endpoint, err := conn.ReadFromUDP(buff) 167 if endpoint != nil { 168 endpoint.IP = endpoint.IP.To4() 169 } 170 return n, (*StdNetEndpoint)(endpoint), err 171 } 172 } 173 174 func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { 175 return func(buff []byte) (int, Endpoint, error) { 176 n, endpoint, err := conn.ReadFromUDP(buff) 177 return n, (*StdNetEndpoint)(endpoint), err 178 } 179 } 180 181 func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { 182 var err error 183 nend, ok := endpoint.(*StdNetEndpoint) 184 if !ok { 185 return ErrWrongEndpointType 186 } 187 188 bind.mu.Lock() 189 blackhole := bind.blackhole4 190 conn := bind.ipv4 191 if nend.IP.To4() == nil { 192 blackhole = bind.blackhole6 193 conn = bind.ipv6 194 } 195 bind.mu.Unlock() 196 197 if blackhole { 198 return nil 199 } 200 if conn == nil { 201 return syscall.EAFNOSUPPORT 202 } 203 _, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend)) 204 return err 205 }