github.com/xraypb/Xray-core@v1.8.1/proxy/wireguard/bind.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "net" 8 "net/netip" 9 "strconv" 10 "sync" 11 12 "github.com/sagernet/wireguard-go/conn" 13 xnet "github.com/xraypb/Xray-core/common/net" 14 "github.com/xraypb/Xray-core/features/dns" 15 "github.com/xraypb/Xray-core/transport/internet" 16 ) 17 18 type netReadInfo struct { 19 // status 20 waiter sync.WaitGroup 21 // param 22 buff []byte 23 // result 24 bytes int 25 endpoint conn.Endpoint 26 err error 27 } 28 29 type netBindClient struct { 30 workers int 31 dialer internet.Dialer 32 dns dns.Client 33 dnsOption dns.IPOption 34 reserved []byte 35 36 readQueue chan *netReadInfo 37 } 38 39 func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { 40 ipStr, port, _, err := splitAddrPort(s) 41 if err != nil { 42 return nil, err 43 } 44 45 var addr net.IP 46 if IsDomainName(ipStr) { 47 ips, err := n.dns.LookupIP(ipStr, n.dnsOption) 48 if err != nil { 49 return nil, err 50 } else if len(ips) == 0 { 51 return nil, dns.ErrEmptyResponse 52 } 53 addr = ips[0] 54 } else { 55 addr = net.ParseIP(ipStr) 56 } 57 if addr == nil { 58 return nil, errors.New("failed to parse ip: " + ipStr) 59 } 60 61 var ip xnet.Address 62 if p4 := addr.To4(); len(p4) == net.IPv4len { 63 ip = xnet.IPAddress(p4[:]) 64 } else { 65 ip = xnet.IPAddress(addr[:]) 66 } 67 68 dst := xnet.Destination{ 69 Address: ip, 70 Port: xnet.Port(port), 71 Network: xnet.Network_UDP, 72 } 73 74 return &netEndpoint{ 75 dst: dst, 76 }, nil 77 } 78 79 func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { 80 bind.readQueue = make(chan *netReadInfo) 81 82 fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) { 83 defer func() { 84 if r := recover(); r != nil { 85 cap = 0 86 ep = nil 87 err = errors.New("channel closed") 88 } 89 }() 90 91 r := &netReadInfo{ 92 buff: buff, 93 } 94 r.waiter.Add(1) 95 bind.readQueue <- r 96 r.waiter.Wait() // wait read goroutine done, or we will miss the result 97 return r.bytes, r.endpoint, r.err 98 } 99 workers := bind.workers 100 if workers <= 0 { 101 workers = 1 102 } 103 arr := make([]conn.ReceiveFunc, workers) 104 for i := 0; i < workers; i++ { 105 arr[i] = fun 106 } 107 108 return arr, uint16(uport), nil 109 } 110 111 func (bind *netBindClient) Close() error { 112 if bind.readQueue != nil { 113 close(bind.readQueue) 114 } 115 return nil 116 } 117 118 func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { 119 c, err := bind.dialer.Dial(context.Background(), endpoint.dst) 120 if err != nil { 121 return err 122 } 123 endpoint.conn = c 124 125 go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) { 126 for { 127 v, ok := <-readQueue 128 if !ok { 129 return 130 } 131 i, err := c.Read(v.buff) 132 133 if i > 3 { 134 v.buff[1] = 0 135 v.buff[2] = 0 136 v.buff[3] = 0 137 } 138 139 v.bytes = i 140 v.endpoint = endpoint 141 v.err = err 142 v.waiter.Done() 143 if err != nil && errors.Is(err, io.EOF) { 144 endpoint.conn = nil 145 return 146 } 147 } 148 }(bind.readQueue, endpoint) 149 150 return nil 151 } 152 153 func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error { 154 var err error 155 156 nend, ok := endpoint.(*netEndpoint) 157 if !ok { 158 return conn.ErrWrongEndpointType 159 } 160 161 if nend.conn == nil { 162 err = bind.connectTo(nend) 163 if err != nil { 164 return err 165 } 166 } 167 168 if len(buff) > 3 && len(bind.reserved) == 3 { 169 copy(buff[1:], bind.reserved) 170 } 171 172 _, err = nend.conn.Write(buff) 173 174 return err 175 } 176 177 func (bind *netBindClient) SetMark(mark uint32) error { 178 return nil 179 } 180 181 type netEndpoint struct { 182 dst xnet.Destination 183 conn net.Conn 184 } 185 186 func (netEndpoint) ClearSrc() {} 187 188 func (e netEndpoint) DstIP() netip.Addr { 189 return toNetIpAddr(e.dst.Address) 190 } 191 192 func (e netEndpoint) SrcIP() netip.Addr { 193 return netip.Addr{} 194 } 195 196 func (e netEndpoint) DstToBytes() []byte { 197 var dat []byte 198 if e.dst.Address.Family().IsIPv4() { 199 dat = e.dst.Address.IP().To4()[:] 200 } else { 201 dat = e.dst.Address.IP().To16()[:] 202 } 203 dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8)) 204 return dat 205 } 206 207 func (e netEndpoint) DstToString() string { 208 return e.dst.NetAddr() 209 } 210 211 func (e netEndpoint) SrcToString() string { 212 return "" 213 } 214 215 func toNetIpAddr(addr xnet.Address) netip.Addr { 216 if addr.Family().IsIPv4() { 217 ip := addr.IP() 218 return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]}) 219 } else { 220 ip := addr.IP() 221 arr := [16]byte{} 222 for i := 0; i < 16; i++ { 223 arr[i] = ip[i] 224 } 225 return netip.AddrFrom16(arr) 226 } 227 } 228 229 func stringsLastIndexByte(s string, b byte) int { 230 for i := len(s) - 1; i >= 0; i-- { 231 if s[i] == b { 232 return i 233 } 234 } 235 return -1 236 } 237 238 func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) { 239 i := stringsLastIndexByte(s, ':') 240 if i == -1 { 241 return "", 0, false, errors.New("not an ip:port") 242 } 243 244 ip = s[:i] 245 portStr := s[i+1:] 246 if len(ip) == 0 { 247 return "", 0, false, errors.New("no IP") 248 } 249 if len(portStr) == 0 { 250 return "", 0, false, errors.New("no port") 251 } 252 port64, err := strconv.ParseUint(portStr, 10, 16) 253 if err != nil { 254 return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s)) 255 } 256 port = uint16(port64) 257 if ip[0] == '[' { 258 if len(ip) < 2 || ip[len(ip)-1] != ']' { 259 return "", 0, false, errors.New("missing ]") 260 } 261 ip = ip[1 : len(ip)-1] 262 v6 = true 263 } 264 265 return ip, port, v6, nil 266 }