github.com/xraypb/xray-core@v1.6.6/proxy/wireguard/bind.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "net" 8 "net/netip" 9 "strconv" 10 "sync" 11 12 "github.com/sagernet/wireguard-go/conn" 13 xnet "github.com/xraypb/xray-core/common/net" 14 "github.com/xraypb/xray-core/features/dns" 15 "github.com/xraypb/xray-core/transport/internet" 16 ) 17 18 type netReadInfo struct { 19 // status 20 waiter sync.WaitGroup 21 // param 22 buff []byte 23 // result 24 bytes int 25 endpoint conn.Endpoint 26 err error 27 } 28 29 type netBindClient struct { 30 workers int 31 dialer internet.Dialer 32 dns dns.Client 33 dnsOption dns.IPOption 34 35 readQueue chan *netReadInfo 36 } 37 38 func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { 39 ipStr, port, _, err := splitAddrPort(s) 40 if err != nil { 41 return nil, err 42 } 43 44 var addr net.IP 45 if IsDomainName(ipStr) { 46 ips, err := n.dns.LookupIP(ipStr, n.dnsOption) 47 if err != nil { 48 return nil, err 49 } else if len(ips) == 0 { 50 return nil, dns.ErrEmptyResponse 51 } 52 addr = ips[0] 53 } else { 54 addr = net.ParseIP(ipStr) 55 } 56 if addr == nil { 57 return nil, errors.New("failed to parse ip: " + ipStr) 58 } 59 60 var ip xnet.Address 61 if p4 := addr.To4(); len(p4) == net.IPv4len { 62 ip = xnet.IPAddress(p4[:]) 63 } else { 64 ip = xnet.IPAddress(addr[:]) 65 } 66 67 dst := xnet.Destination{ 68 Address: ip, 69 Port: xnet.Port(port), 70 Network: xnet.Network_UDP, 71 } 72 73 return &netEndpoint{ 74 dst: dst, 75 }, nil 76 } 77 78 func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { 79 bind.readQueue = make(chan *netReadInfo) 80 81 fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) { 82 defer func() { 83 if r := recover(); r != nil { 84 cap = 0 85 ep = nil 86 err = errors.New("channel closed") 87 } 88 }() 89 90 r := &netReadInfo{ 91 buff: buff, 92 } 93 r.waiter.Add(1) 94 bind.readQueue <- r 95 r.waiter.Wait() // wait read goroutine done, or we will miss the result 96 return r.bytes, r.endpoint, r.err 97 } 98 workers := bind.workers 99 if workers <= 0 { 100 workers = 1 101 } 102 arr := make([]conn.ReceiveFunc, workers) 103 for i := 0; i < workers; i++ { 104 arr[i] = fun 105 } 106 107 return arr, uint16(uport), nil 108 } 109 110 func (bind *netBindClient) Close() error { 111 if bind.readQueue != nil { 112 close(bind.readQueue) 113 } 114 return nil 115 } 116 117 func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { 118 c, err := bind.dialer.Dial(context.Background(), endpoint.dst) 119 if err != nil { 120 return err 121 } 122 endpoint.conn = c 123 124 go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) { 125 for { 126 v, ok := <-readQueue 127 if !ok { 128 return 129 } 130 i, err := c.Read(v.buff) 131 v.bytes = i 132 v.endpoint = endpoint 133 v.err = err 134 v.waiter.Done() 135 if err != nil && errors.Is(err, io.EOF) { 136 endpoint.conn = nil 137 return 138 } 139 } 140 }(bind.readQueue, endpoint) 141 142 return nil 143 } 144 145 func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error { 146 var err error 147 148 nend, ok := endpoint.(*netEndpoint) 149 if !ok { 150 return conn.ErrWrongEndpointType 151 } 152 153 if nend.conn == nil { 154 err = bind.connectTo(nend) 155 if err != nil { 156 return err 157 } 158 } 159 160 _, err = nend.conn.Write(buff) 161 162 return err 163 } 164 165 func (bind *netBindClient) SetMark(mark uint32) error { 166 return nil 167 } 168 169 type netEndpoint struct { 170 dst xnet.Destination 171 conn net.Conn 172 } 173 174 func (netEndpoint) ClearSrc() {} 175 176 func (e netEndpoint) DstIP() netip.Addr { 177 return toNetIpAddr(e.dst.Address) 178 } 179 180 func (e netEndpoint) SrcIP() netip.Addr { 181 return netip.Addr{} 182 } 183 184 func (e netEndpoint) DstToBytes() []byte { 185 var dat []byte 186 if e.dst.Address.Family().IsIPv4() { 187 dat = e.dst.Address.IP().To4()[:] 188 } else { 189 dat = e.dst.Address.IP().To16()[:] 190 } 191 dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8)) 192 return dat 193 } 194 195 func (e netEndpoint) DstToString() string { 196 return e.dst.NetAddr() 197 } 198 199 func (e netEndpoint) SrcToString() string { 200 return "" 201 } 202 203 func toNetIpAddr(addr xnet.Address) netip.Addr { 204 if addr.Family().IsIPv4() { 205 ip := addr.IP() 206 return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]}) 207 } else { 208 ip := addr.IP() 209 arr := [16]byte{} 210 for i := 0; i < 16; i++ { 211 arr[i] = ip[i] 212 } 213 return netip.AddrFrom16(arr) 214 } 215 } 216 217 func stringsLastIndexByte(s string, b byte) int { 218 for i := len(s) - 1; i >= 0; i-- { 219 if s[i] == b { 220 return i 221 } 222 } 223 return -1 224 } 225 226 func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) { 227 i := stringsLastIndexByte(s, ':') 228 if i == -1 { 229 return "", 0, false, errors.New("not an ip:port") 230 } 231 232 ip = s[:i] 233 portStr := s[i+1:] 234 if len(ip) == 0 { 235 return "", 0, false, errors.New("no IP") 236 } 237 if len(portStr) == 0 { 238 return "", 0, false, errors.New("no port") 239 } 240 port64, err := strconv.ParseUint(portStr, 10, 16) 241 if err != nil { 242 return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s)) 243 } 244 port = uint16(port64) 245 if ip[0] == '[' { 246 if len(ip) < 2 || ip[len(ip)-1] != ']' { 247 return "", 0, false, errors.New("missing ]") 248 } 249 ip = ip[1 : len(ip)-1] 250 v6 = true 251 } 252 253 return ip, port, v6, nil 254 }