github.com/xmplusdev/xray-core@v1.8.10/proxy/wireguard/bind.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "net" 8 "net/netip" 9 "strconv" 10 "sync" 11 12 "golang.zx2c4.com/wireguard/conn" 13 14 xnet "github.com/xmplusdev/xray-core/common/net" 15 "github.com/xmplusdev/xray-core/features/dns" 16 "github.com/xmplusdev/xray-core/transport/internet" 17 ) 18 19 type netReadInfo struct { 20 // status 21 waiter sync.WaitGroup 22 // param 23 buff []byte 24 // result 25 bytes int 26 endpoint conn.Endpoint 27 err error 28 } 29 30 // reduce duplicated code 31 type netBind struct { 32 dns dns.Client 33 dnsOption dns.IPOption 34 35 workers int 36 readQueue chan *netReadInfo 37 } 38 39 // SetMark implements conn.Bind 40 func (bind *netBind) SetMark(mark uint32) error { 41 return nil 42 } 43 44 // ParseEndpoint implements conn.Bind 45 func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) { 46 ipStr, port, err := net.SplitHostPort(s) 47 if err != nil { 48 return nil, err 49 } 50 portNum, err := strconv.Atoi(port) 51 if err != nil { 52 return nil, err 53 } 54 55 addr := xnet.ParseAddress(ipStr) 56 if addr.Family() == xnet.AddressFamilyDomain { 57 ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption) 58 if err != nil { 59 return nil, err 60 } else if len(ips) == 0 { 61 return nil, dns.ErrEmptyResponse 62 } 63 addr = xnet.IPAddress(ips[0]) 64 } 65 66 dst := xnet.Destination{ 67 Address: addr, 68 Port: xnet.Port(portNum), 69 Network: xnet.Network_UDP, 70 } 71 72 return &netEndpoint{ 73 dst: dst, 74 }, nil 75 } 76 77 // BatchSize implements conn.Bind 78 func (bind *netBind) BatchSize() int { 79 return 1 80 } 81 82 // Open implements conn.Bind 83 func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { 84 bind.readQueue = make(chan *netReadInfo) 85 86 fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { 87 defer func() { 88 if r := recover(); r != nil { 89 n = 0 90 err = errors.New("channel closed") 91 } 92 }() 93 94 r := &netReadInfo{ 95 buff: bufs[0], 96 } 97 r.waiter.Add(1) 98 bind.readQueue <- r 99 r.waiter.Wait() // wait read goroutine done, or we will miss the result 100 sizes[0], eps[0] = r.bytes, r.endpoint 101 return 1, r.err 102 } 103 workers := bind.workers 104 if workers <= 0 { 105 workers = 1 106 } 107 arr := make([]conn.ReceiveFunc, workers) 108 for i := 0; i < workers; i++ { 109 arr[i] = fun 110 } 111 112 return arr, uint16(uport), nil 113 } 114 115 // Close implements conn.Bind 116 func (bind *netBind) Close() error { 117 if bind.readQueue != nil { 118 close(bind.readQueue) 119 } 120 return nil 121 } 122 123 type netBindClient struct { 124 netBind 125 126 dialer internet.Dialer 127 reserved []byte 128 } 129 130 func (bind *netBindClient) connectTo(endpoint *netEndpoint) error { 131 c, err := bind.dialer.Dial(context.Background(), endpoint.dst) 132 if err != nil { 133 return err 134 } 135 endpoint.conn = c 136 137 go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) { 138 for { 139 v, ok := <-readQueue 140 if !ok { 141 return 142 } 143 i, err := c.Read(v.buff) 144 145 if i > 3 { 146 v.buff[1] = 0 147 v.buff[2] = 0 148 v.buff[3] = 0 149 } 150 151 v.bytes = i 152 v.endpoint = endpoint 153 v.err = err 154 v.waiter.Done() 155 if err != nil && errors.Is(err, io.EOF) { 156 endpoint.conn = nil 157 return 158 } 159 } 160 }(bind.readQueue, endpoint) 161 162 return nil 163 } 164 165 func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error { 166 var err error 167 168 nend, ok := endpoint.(*netEndpoint) 169 if !ok { 170 return conn.ErrWrongEndpointType 171 } 172 173 if nend.conn == nil { 174 err = bind.connectTo(nend) 175 if err != nil { 176 return err 177 } 178 } 179 180 for _, buff := range buff { 181 if len(buff) > 3 && len(bind.reserved) == 3 { 182 copy(buff[1:], bind.reserved) 183 } 184 if _, err = nend.conn.Write(buff); err != nil { 185 return err 186 } 187 } 188 return nil 189 } 190 191 type netBindServer struct { 192 netBind 193 } 194 195 func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error { 196 var err error 197 198 nend, ok := endpoint.(*netEndpoint) 199 if !ok { 200 return conn.ErrWrongEndpointType 201 } 202 203 if nend.conn == nil { 204 return newError("connection not open yet") 205 } 206 207 for _, buff := range buff { 208 if _, err = nend.conn.Write(buff); err != nil { 209 return err 210 } 211 } 212 213 return err 214 } 215 216 type netEndpoint struct { 217 dst xnet.Destination 218 conn net.Conn 219 } 220 221 func (netEndpoint) ClearSrc() {} 222 223 func (e netEndpoint) DstIP() netip.Addr { 224 return netip.Addr{} 225 } 226 227 func (e netEndpoint) SrcIP() netip.Addr { 228 return netip.Addr{} 229 } 230 231 func (e netEndpoint) DstToBytes() []byte { 232 var dat []byte 233 if e.dst.Address.Family().IsIPv4() { 234 dat = e.dst.Address.IP().To4()[:] 235 } else { 236 dat = e.dst.Address.IP().To16()[:] 237 } 238 dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8)) 239 return dat 240 } 241 242 func (e netEndpoint) DstToString() string { 243 return e.dst.NetAddr() 244 } 245 246 func (e netEndpoint) SrcToString() string { 247 return "" 248 } 249 250 func toNetIpAddr(addr xnet.Address) netip.Addr { 251 if addr.Family().IsIPv4() { 252 ip := addr.IP() 253 return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]}) 254 } else { 255 ip := addr.IP() 256 arr := [16]byte{} 257 for i := 0; i < 16; i++ { 258 arr[i] = ip[i] 259 } 260 return netip.AddrFrom16(arr) 261 } 262 }