github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/conn/bind_std.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package conn 7 8 import ( 9 "errors" 10 "net" 11 "sync" 12 "syscall" 13 ) 14 15 // StdNetBind is meant to be a temporary solution on platforms for which 16 // the sticky socket / source caching behavior has not yet been implemented. 17 // It uses the Go's net package to implement networking. 18 // See LinuxSocketBind for a proper implementation on the Linux platform. 19 type StdNetBind struct { 20 mu sync.Mutex // protects following fields 21 ipv4 *net.UDPConn 22 ipv6 *net.UDPConn 23 blackhole4 bool 24 blackhole6 bool 25 } 26 27 func NewStdNetBind() Bind { return &StdNetBind{} } 28 29 type StdNetEndpoint net.UDPAddr 30 31 var _ Bind = (*StdNetBind)(nil) 32 var _ Endpoint = (*StdNetEndpoint)(nil) 33 34 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { 35 addr, err := parseEndpoint(s) 36 return (*StdNetEndpoint)(addr), err 37 } 38 39 func (*StdNetEndpoint) ClearSrc() {} 40 41 func (e *StdNetEndpoint) DstIP() net.IP { 42 return (*net.UDPAddr)(e).IP 43 } 44 45 func (e *StdNetEndpoint) SrcIP() net.IP { 46 return nil // not supported 47 } 48 49 func (e *StdNetEndpoint) DstToBytes() []byte { 50 addr := (*net.UDPAddr)(e) 51 out := addr.IP.To4() 52 if out == nil { 53 out = addr.IP 54 } 55 out = append(out, byte(addr.Port&0xff)) 56 out = append(out, byte((addr.Port>>8)&0xff)) 57 return out 58 } 59 60 func (e *StdNetEndpoint) DstToString() string { 61 return (*net.UDPAddr)(e).String() 62 } 63 64 func (e *StdNetEndpoint) SrcToString() string { 65 return "" 66 } 67 68 func listenNet(network string, port int) (*net.UDPConn, int, error) { 69 conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) 70 if err != nil { 71 return nil, 0, err 72 } 73 74 // Retrieve port. 75 laddr := conn.LocalAddr() 76 uaddr, err := net.ResolveUDPAddr( 77 laddr.Network(), 78 laddr.String(), 79 ) 80 if err != nil { 81 return nil, 0, err 82 } 83 return conn, uaddr.Port, nil 84 } 85 86 func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { 87 bind.mu.Lock() 88 defer bind.mu.Unlock() 89 90 var err error 91 var tries int 92 93 if bind.ipv4 != nil || bind.ipv6 != nil { 94 return nil, 0, ErrBindAlreadyOpen 95 } 96 97 // Attempt to open ipv4 and ipv6 listeners on the same port. 98 // If uport is 0, we can retry on failure. 99 again: 100 port := int(uport) 101 var ipv4, ipv6 *net.UDPConn 102 103 ipv4, port, err = listenNet("udp4", port) 104 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 105 return nil, 0, err 106 } 107 108 // Listen on the same port as we're using for ipv4. 109 ipv6, port, err = listenNet("udp6", port) 110 if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { 111 ipv4.Close() 112 tries++ 113 goto again 114 } 115 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { 116 ipv4.Close() 117 return nil, 0, err 118 } 119 var fns []ReceiveFunc 120 if ipv4 != nil { 121 fns = append(fns, bind.makeReceiveIPv4(ipv4)) 122 bind.ipv4 = ipv4 123 } 124 if ipv6 != nil { 125 fns = append(fns, bind.makeReceiveIPv6(ipv6)) 126 bind.ipv6 = ipv6 127 } 128 if len(fns) == 0 { 129 return nil, 0, syscall.EAFNOSUPPORT 130 } 131 return fns, uint16(port), nil 132 } 133 134 func (bind *StdNetBind) Close() error { 135 bind.mu.Lock() 136 defer bind.mu.Unlock() 137 138 var err1, err2 error 139 if bind.ipv4 != nil { 140 err1 = bind.ipv4.Close() 141 bind.ipv4 = nil 142 } 143 if bind.ipv6 != nil { 144 err2 = bind.ipv6.Close() 145 bind.ipv6 = nil 146 } 147 bind.blackhole4 = false 148 bind.blackhole6 = false 149 if err1 != nil { 150 return err1 151 } 152 return err2 153 } 154 155 func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc { 156 return func(buff []byte) (int, Endpoint, error) { 157 n, endpoint, err := conn.ReadFromUDP(buff) 158 if endpoint != nil { 159 endpoint.IP = endpoint.IP.To4() 160 } 161 return n, (*StdNetEndpoint)(endpoint), err 162 } 163 } 164 165 func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc { 166 return func(buff []byte) (int, Endpoint, error) { 167 n, endpoint, err := conn.ReadFromUDP(buff) 168 return n, (*StdNetEndpoint)(endpoint), err 169 } 170 } 171 172 func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { 173 var err error 174 nend, ok := endpoint.(*StdNetEndpoint) 175 if !ok { 176 return ErrWrongEndpointType 177 } 178 179 bind.mu.Lock() 180 blackhole := bind.blackhole4 181 conn := bind.ipv4 182 if nend.IP.To4() == nil { 183 blackhole = bind.blackhole6 184 conn = bind.ipv6 185 } 186 bind.mu.Unlock() 187 188 if blackhole { 189 return nil 190 } 191 if conn == nil { 192 return syscall.EAFNOSUPPORT 193 } 194 _, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend)) 195 return err 196 }