github.com/moqsien/xraycore@v1.8.5/proxy/wireguard/bind.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "net" 8 "net/netip" 9 "strconv" 10 "sync" 11 12 "github.com/sagernet/wireguard-go/conn" 13 xnet "github.com/moqsien/xraycore/common/net" 14 "github.com/moqsien/xraycore/features/dns" 15 "github.com/moqsien/xraycore/transport/internet" 16 ) 17 18 type netReadInfo struct { 19 // status 20 waiter sync.WaitGroup 21 // param 22 buff [][]byte 23 // result 24 bytes int 25 endpoint conn.Endpoint 26 err error 27 } 28 29 type netBindClient struct { 30 workers int 31 dialer internet.Dialer 32 dns dns.Client 33 dnsOption dns.IPOption 34 reserved []byte 35 36 readQueue chan *netReadInfo 37 } 38 39 func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) { 40 ipStr, port, _, err := splitAddrPort(s) 41 if err != nil { 42 return nil, err 43 } 44 45 var addr net.IP 46 if IsDomainName(ipStr) { 47 ips, err := n.dns.LookupIP(ipStr, n.dnsOption) 48 if err != nil { 49 return nil, err 50 } else if len(ips) == 0 { 51 return nil, dns.ErrEmptyResponse 52 } 53 addr = ips[0] 54 } else { 55 addr = net.ParseIP(ipStr) 56 } 57 if addr == nil { 58 return nil, errors.New("failed to parse ip: " + ipStr) 59 } 60 61 var ip xnet.Address 62 if p4 := addr.To4(); len(p4) == net.IPv4len { 63 ip = xnet.IPAddress(p4[:]) 64 } else { 65 ip = xnet.IPAddress(addr[:]) 66 } 67 68 dst := xnet.Destination{ 69 Address: ip, 70 Port: xnet.Port(port), 71 Network: xnet.Network_UDP, 72 } 73 74 return &netEndpoint{ 75 dst: dst, 76 }, nil 77 } 78 79 func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { 80 bind.readQueue = make(chan *netReadInfo) 81 82 fun := func(buff [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) { 83 defer func() { 84 if r := recover(); r != nil { 85 count = 0 86 err = errors.New("channel closed") 87 } 88 }() 89 90 r := &netReadInfo{ 91 buff: buff, 92 } 93 r.waiter.Add(1) 94 bind.readQueue <- r 95 r.waiter.Wait() // wait read goroutine done, or we will miss the result 96 return r.bytes, r.err 97 } 98 workers := bind.workers 99 if workers <= 0 { 100 workers = 1 101 } 102 arr := make([]conn.ReceiveFunc, workers) 103 for i := 0; i < workers; i++ { 104 arr[i] = fun 105 } 106 107 return arr, uint16(uport), nil 108 } 109 110 func (bind *netBindClient) Close() error { 111 if bind.readQueue != nil { 112 close(bind.readQueue) 113 } 114 return nil 115 } 116 117 func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { 118 c, err := bind.dialer.Dial(context.Background(), endpoint.dst) 119 if err != nil { 120 return err 121 } 122 endpoint.conn = c 123 124 go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) { 125 for { 126 v, ok := <-readQueue 127 if !ok { 128 return 129 } 130 b := v.buff[0] 131 i, err := c.Read(b) 132 133 if i > 3 { 134 b[1] = 0 135 b[2] = 0 136 b[3] = 0 137 } 138 139 v.bytes = i 140 v.endpoint = endpoint 141 v.err = err 142 v.waiter.Done() 143 if err != nil && errors.Is(err, io.EOF) { 144 endpoint.conn = nil 145 return 146 } 147 } 148 }(bind.readQueue, endpoint) 149 150 return nil 151 } 152 153 func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error { 154 var err error 155 156 nend, ok := endpoint.(*netEndpoint) 157 if !ok { 158 return conn.ErrWrongEndpointType 159 } 160 161 if nend.conn == nil { 162 err = bind.connectTo(nend) 163 if err != nil { 164 return err 165 } 166 } 167 168 if len(buff[0]) > 3 && len(bind.reserved) == 3 { 169 copy(buff[0][1:], bind.reserved) 170 } 171 172 _, err = nend.conn.Write(buff[0]) 173 174 return err 175 } 176 177 func (bind *netBindClient) SetMark(mark uint32) error { 178 return nil 179 } 180 181 func (bind *netBindClient) BatchSize() int { 182 return 1 183 } 184 185 type netEndpoint struct { 186 dst xnet.Destination 187 conn net.Conn 188 } 189 190 func (netEndpoint) ClearSrc() {} 191 192 func (e netEndpoint) DstIP() netip.Addr { 193 return toNetIpAddr(e.dst.Address) 194 } 195 196 func (e netEndpoint) SrcIP() netip.Addr { 197 return netip.Addr{} 198 } 199 200 func (e netEndpoint) DstToBytes() []byte { 201 var dat []byte 202 if e.dst.Address.Family().IsIPv4() { 203 dat = e.dst.Address.IP().To4()[:] 204 } else { 205 dat = e.dst.Address.IP().To16()[:] 206 } 207 dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8)) 208 return dat 209 } 210 211 func (e netEndpoint) DstToString() string { 212 return e.dst.NetAddr() 213 } 214 215 func (e netEndpoint) SrcToString() string { 216 return "" 217 } 218 219 func toNetIpAddr(addr xnet.Address) netip.Addr { 220 if addr.Family().IsIPv4() { 221 ip := addr.IP() 222 return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]}) 223 } else { 224 ip := addr.IP() 225 arr := [16]byte{} 226 for i := 0; i < 16; i++ { 227 arr[i] = ip[i] 228 } 229 return netip.AddrFrom16(arr) 230 } 231 } 232 233 func stringsLastIndexByte(s string, b byte) int { 234 for i := len(s) - 1; i >= 0; i-- { 235 if s[i] == b { 236 return i 237 } 238 } 239 return -1 240 } 241 242 func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) { 243 i := stringsLastIndexByte(s, ':') 244 if i == -1 { 245 return "", 0, false, errors.New("not an ip:port") 246 } 247 248 ip = s[:i] 249 portStr := s[i+1:] 250 if len(ip) == 0 { 251 return "", 0, false, errors.New("no IP") 252 } 253 if len(portStr) == 0 { 254 return "", 0, false, errors.New("no port") 255 } 256 port64, err := strconv.ParseUint(portStr, 10, 16) 257 if err != nil { 258 return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s)) 259 } 260 port = uint16(port64) 261 if ip[0] == '[' { 262 if len(ip) < 2 || ip[len(ip)-1] != ']' { 263 return "", 0, false, errors.New("missing ]") 264 } 265 ip = ip[1 : len(ip)-1] 266 v6 = true 267 } 268 269 return ip, port, v6, nil 270 }