github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/wireguard/client.go (about) 1 /* 2 3 Some of codes are copied from https://github.com/octeep/wireproxy, license below. 4 5 Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me> 6 7 Permission to use, copy, modify, and distribute this software for any 8 purpose with or without fee is hereby granted, provided that the above 9 copyright notice and this permission notice appear in all copies. 10 11 THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 12 WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 13 MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 14 ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 15 WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 16 ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 17 OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 18 19 */ 20 21 package wireguard 22 23 import ( 24 "context" 25 "fmt" 26 "net/netip" 27 "strings" 28 "sync" 29 30 "github.com/xtls/xray-core/common" 31 "github.com/xtls/xray-core/common/buf" 32 "github.com/xtls/xray-core/common/dice" 33 "github.com/xtls/xray-core/common/log" 34 "github.com/xtls/xray-core/common/net" 35 "github.com/xtls/xray-core/common/protocol" 36 "github.com/xtls/xray-core/common/session" 37 "github.com/xtls/xray-core/common/signal" 38 "github.com/xtls/xray-core/common/task" 39 "github.com/xtls/xray-core/core" 40 "github.com/xtls/xray-core/features/dns" 41 "github.com/xtls/xray-core/features/policy" 42 "github.com/xtls/xray-core/transport" 43 "github.com/xtls/xray-core/transport/internet" 44 ) 45 46 // Handler is an outbound connection that silently swallow the entire payload. 47 type Handler struct { 48 conf *DeviceConfig 49 net Tunnel 50 bind *netBindClient 51 policyManager policy.Manager 52 dns dns.Client 53 // cached configuration 54 endpoints []netip.Addr 55 hasIPv4, hasIPv6 bool 56 wgLock sync.Mutex 57 } 58 59 // New creates a new wireguard handler. 60 func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { 61 v := core.MustFromContext(ctx) 62 63 endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) 64 if err != nil { 65 return nil, err 66 } 67 68 d := v.GetFeature(dns.ClientType()).(dns.Client) 69 return &Handler{ 70 conf: conf, 71 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 72 dns: d, 73 endpoints: endpoints, 74 hasIPv4: hasIPv4, 75 hasIPv6: hasIPv6, 76 }, nil 77 } 78 79 func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { 80 h.wgLock.Lock() 81 defer h.wgLock.Unlock() 82 83 if h.bind != nil && h.bind.dialer == dialer && h.net != nil { 84 return nil 85 } 86 87 log.Record(&log.GeneralMessage{ 88 Severity: log.Severity_Info, 89 Content: "switching dialer", 90 }) 91 92 if h.net != nil { 93 _ = h.net.Close() 94 h.net = nil 95 } 96 if h.bind != nil { 97 _ = h.bind.Close() 98 h.bind = nil 99 } 100 101 // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer 102 bind := &netBindClient{ 103 netBind: netBind{ 104 dns: h.dns, 105 dnsOption: dns.IPOption{ 106 IPv4Enable: h.hasIPv4, 107 IPv6Enable: h.hasIPv6, 108 }, 109 workers: int(h.conf.NumWorkers), 110 }, 111 dialer: dialer, 112 reserved: h.conf.Reserved, 113 } 114 defer func() { 115 if err != nil { 116 _ = bind.Close() 117 } 118 }() 119 120 h.net, err = h.makeVirtualTun(bind) 121 if err != nil { 122 return newError("failed to create virtual tun interface").Base(err) 123 } 124 h.bind = bind 125 return nil 126 } 127 128 // Process implements OutboundHandler.Dispatch(). 129 func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { 130 outbounds := session.OutboundsFromContext(ctx) 131 ob := outbounds[len(outbounds) - 1] 132 if !ob.Target.IsValid() { 133 return newError("target not specified") 134 } 135 ob.Name = "wireguard" 136 ob.CanSpliceCopy = 3 137 138 if err := h.processWireGuard(dialer); err != nil { 139 return err 140 } 141 142 // Destination of the inner request. 143 destination := ob.Target 144 command := protocol.RequestCommandTCP 145 if destination.Network == net.Network_UDP { 146 command = protocol.RequestCommandUDP 147 } 148 149 // resolve dns 150 addr := destination.Address 151 if addr.Family().IsDomain() { 152 ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ 153 IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), 154 IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), 155 }) 156 { // Resolve fallback 157 if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { 158 ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ 159 IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), 160 IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), 161 }) 162 } 163 } 164 if err != nil { 165 return newError("failed to lookup DNS").Base(err) 166 } else if len(ips) == 0 { 167 return dns.ErrEmptyResponse 168 } 169 addr = net.IPAddress(ips[dice.Roll(len(ips))]) 170 } 171 172 var newCtx context.Context 173 var newCancel context.CancelFunc 174 if session.TimeoutOnlyFromContext(ctx) { 175 newCtx, newCancel = context.WithCancel(context.Background()) 176 } 177 178 p := h.policyManager.ForLevel(0) 179 180 ctx, cancel := context.WithCancel(ctx) 181 timer := signal.CancelAfterInactivity(ctx, func() { 182 cancel() 183 if newCancel != nil { 184 newCancel() 185 } 186 }, p.Timeouts.ConnectionIdle) 187 addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) 188 189 var requestFunc func() error 190 var responseFunc func() error 191 192 if command == protocol.RequestCommandTCP { 193 conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) 194 if err != nil { 195 return newError("failed to create TCP connection").Base(err) 196 } 197 defer conn.Close() 198 199 requestFunc = func() error { 200 defer timer.SetTimeout(p.Timeouts.DownlinkOnly) 201 return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) 202 } 203 responseFunc = func() error { 204 defer timer.SetTimeout(p.Timeouts.UplinkOnly) 205 return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) 206 } 207 } else if command == protocol.RequestCommandUDP { 208 conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) 209 if err != nil { 210 return newError("failed to create UDP connection").Base(err) 211 } 212 defer conn.Close() 213 214 requestFunc = func() error { 215 defer timer.SetTimeout(p.Timeouts.DownlinkOnly) 216 return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) 217 } 218 responseFunc = func() error { 219 defer timer.SetTimeout(p.Timeouts.UplinkOnly) 220 return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) 221 } 222 } 223 224 if newCtx != nil { 225 ctx = newCtx 226 } 227 228 responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) 229 if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { 230 common.Interrupt(link.Reader) 231 common.Interrupt(link.Writer) 232 return newError("connection ends").Base(err) 233 } 234 235 return nil 236 } 237 238 // creates a tun interface on netstack given a configuration 239 func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) { 240 t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil) 241 if err != nil { 242 return nil, err 243 } 244 245 bind.dnsOption.IPv4Enable = h.hasIPv4 246 bind.dnsOption.IPv6Enable = h.hasIPv6 247 248 if err = t.BuildDevice(h.createIPCRequest(bind, h.conf), bind); err != nil { 249 _ = t.Close() 250 return nil, err 251 } 252 return t, nil 253 } 254 255 // serialize the config into an IPC request 256 func (h *Handler) createIPCRequest(bind *netBindClient, conf *DeviceConfig) string { 257 var request strings.Builder 258 259 request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) 260 261 if !conf.IsClient { 262 // placeholder, we'll handle actual port listening on Xray 263 request.WriteString("listen_port=1337\n") 264 } 265 266 for _, peer := range conf.Peers { 267 if peer.PublicKey != "" { 268 request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey)) 269 } 270 271 if peer.PreSharedKey != "" { 272 request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey)) 273 } 274 275 address, port, err := net.SplitHostPort(peer.Endpoint) 276 if err != nil { 277 newError("failed to split endpoint ", peer.Endpoint, " into address and port").AtError().WriteToLog() 278 } 279 addr := net.ParseAddress(address) 280 if addr.Family().IsDomain() { 281 dialerIp := bind.dialer.DestIpAddress() 282 if dialerIp != nil { 283 addr = net.ParseAddress(dialerIp.String()) 284 newError("createIPCRequest use dialer dest ip: ", addr).WriteToLog() 285 } else { 286 ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ 287 IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), 288 IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), 289 }) 290 { // Resolve fallback 291 if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { 292 ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ 293 IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), 294 IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), 295 }) 296 } 297 } 298 if err != nil { 299 newError("createIPCRequest failed to lookup DNS").Base(err).WriteToLog() 300 } else if len(ips) == 0 { 301 newError("createIPCRequest empty lookup DNS").WriteToLog() 302 } else { 303 addr = net.IPAddress(ips[dice.Roll(len(ips))]) 304 } 305 } 306 } 307 308 if peer.Endpoint != "" { 309 request.WriteString(fmt.Sprintf("endpoint=%s:%s\n", addr, port)) 310 } 311 312 for _, ip := range peer.AllowedIps { 313 request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) 314 } 315 316 if peer.KeepAlive != 0 { 317 request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive)) 318 } 319 } 320 321 return request.String()[:request.Len()] 322 }