github.com/xmplusdev/xray-core@v1.8.10/proxy/wireguard/client.go (about) 1 /* 2 3 Some of codes are copied from https://github.com/octeep/wireproxy, license below. 4 5 Copyright (c) 2022 Wind T.F. Wong <octeep@pm.me> 6 7 Permission to use, copy, modify, and distribute this software for any 8 purpose with or without fee is hereby granted, provided that the above 9 copyright notice and this permission notice appear in all copies. 10 11 THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 12 WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 13 MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 14 ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 15 WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 16 ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 17 OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 18 19 */ 20 21 package wireguard 22 23 import ( 24 "context" 25 "fmt" 26 "net/netip" 27 "strings" 28 "sync" 29 30 "github.com/xmplusdev/xray-core/common" 31 "github.com/xmplusdev/xray-core/common/buf" 32 "github.com/xmplusdev/xray-core/common/dice" 33 "github.com/xmplusdev/xray-core/common/log" 34 "github.com/xmplusdev/xray-core/common/net" 35 "github.com/xmplusdev/xray-core/common/protocol" 36 "github.com/xmplusdev/xray-core/common/session" 37 "github.com/xmplusdev/xray-core/common/signal" 38 "github.com/xmplusdev/xray-core/common/task" 39 "github.com/xmplusdev/xray-core/core" 40 "github.com/xmplusdev/xray-core/features/dns" 41 "github.com/xmplusdev/xray-core/features/policy" 42 "github.com/xmplusdev/xray-core/transport" 43 "github.com/xmplusdev/xray-core/transport/internet" 44 ) 45 46 // Handler is an outbound connection that silently swallow the entire payload. 47 type Handler struct { 48 conf *DeviceConfig 49 net Tunnel 50 bind *netBindClient 51 policyManager policy.Manager 52 dns dns.Client 53 // cached configuration 54 endpoints []netip.Addr 55 hasIPv4, hasIPv6 bool 56 wgLock sync.Mutex 57 } 58 59 // New creates a new wireguard handler. 60 func New(ctx context.Context, conf *DeviceConfig) (*Handler, error) { 61 v := core.MustFromContext(ctx) 62 63 endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf) 64 if err != nil { 65 return nil, err 66 } 67 68 d := v.GetFeature(dns.ClientType()).(dns.Client) 69 return &Handler{ 70 conf: conf, 71 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 72 dns: d, 73 endpoints: endpoints, 74 hasIPv4: hasIPv4, 75 hasIPv6: hasIPv6, 76 }, nil 77 } 78 79 func (h *Handler) processWireGuard(dialer internet.Dialer) (err error) { 80 h.wgLock.Lock() 81 defer h.wgLock.Unlock() 82 83 if h.bind != nil && h.bind.dialer == dialer && h.net != nil { 84 return nil 85 } 86 87 log.Record(&log.GeneralMessage{ 88 Severity: log.Severity_Info, 89 Content: "switching dialer", 90 }) 91 92 if h.net != nil { 93 _ = h.net.Close() 94 h.net = nil 95 } 96 if h.bind != nil { 97 _ = h.bind.Close() 98 h.bind = nil 99 } 100 101 // bind := conn.NewStdNetBind() // TODO: conn.Bind wrapper for dialer 102 bind := &netBindClient{ 103 netBind: netBind{ 104 dns: h.dns, 105 dnsOption: dns.IPOption{ 106 IPv4Enable: h.hasIPv4, 107 IPv6Enable: h.hasIPv6, 108 }, 109 workers: int(h.conf.NumWorkers), 110 }, 111 dialer: dialer, 112 reserved: h.conf.Reserved, 113 } 114 defer func() { 115 if err != nil { 116 _ = bind.Close() 117 } 118 }() 119 120 h.net, err = h.makeVirtualTun(bind) 121 if err != nil { 122 return newError("failed to create virtual tun interface").Base(err) 123 } 124 h.bind = bind 125 return nil 126 } 127 128 // Process implements OutboundHandler.Dispatch(). 129 func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { 130 outbound := session.OutboundFromContext(ctx) 131 if outbound == nil || !outbound.Target.IsValid() { 132 return newError("target not specified") 133 } 134 outbound.Name = "wireguard" 135 inbound := session.InboundFromContext(ctx) 136 if inbound != nil { 137 inbound.SetCanSpliceCopy(3) 138 } 139 140 if err := h.processWireGuard(dialer); err != nil { 141 return err 142 } 143 144 // Destination of the inner request. 145 destination := outbound.Target 146 command := protocol.RequestCommandTCP 147 if destination.Network == net.Network_UDP { 148 command = protocol.RequestCommandUDP 149 } 150 151 // resolve dns 152 addr := destination.Address 153 if addr.Family().IsDomain() { 154 ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ 155 IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), 156 IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), 157 }) 158 { // Resolve fallback 159 if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { 160 ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ 161 IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), 162 IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), 163 }) 164 } 165 } 166 if err != nil { 167 return newError("failed to lookup DNS").Base(err) 168 } else if len(ips) == 0 { 169 return dns.ErrEmptyResponse 170 } 171 addr = net.IPAddress(ips[dice.Roll(len(ips))]) 172 } 173 174 var newCtx context.Context 175 var newCancel context.CancelFunc 176 if session.TimeoutOnlyFromContext(ctx) { 177 newCtx, newCancel = context.WithCancel(context.Background()) 178 } 179 180 p := h.policyManager.ForLevel(0) 181 182 ctx, cancel := context.WithCancel(ctx) 183 timer := signal.CancelAfterInactivity(ctx, func() { 184 cancel() 185 if newCancel != nil { 186 newCancel() 187 } 188 }, p.Timeouts.ConnectionIdle) 189 addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) 190 191 var requestFunc func() error 192 var responseFunc func() error 193 194 if command == protocol.RequestCommandTCP { 195 conn, err := h.net.DialContextTCPAddrPort(ctx, addrPort) 196 if err != nil { 197 return newError("failed to create TCP connection").Base(err) 198 } 199 defer conn.Close() 200 201 requestFunc = func() error { 202 defer timer.SetTimeout(p.Timeouts.DownlinkOnly) 203 return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) 204 } 205 responseFunc = func() error { 206 defer timer.SetTimeout(p.Timeouts.UplinkOnly) 207 return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) 208 } 209 } else if command == protocol.RequestCommandUDP { 210 conn, err := h.net.DialUDPAddrPort(netip.AddrPort{}, addrPort) 211 if err != nil { 212 return newError("failed to create UDP connection").Base(err) 213 } 214 defer conn.Close() 215 216 requestFunc = func() error { 217 defer timer.SetTimeout(p.Timeouts.DownlinkOnly) 218 return buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)) 219 } 220 responseFunc = func() error { 221 defer timer.SetTimeout(p.Timeouts.UplinkOnly) 222 return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) 223 } 224 } 225 226 if newCtx != nil { 227 ctx = newCtx 228 } 229 230 responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) 231 if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { 232 common.Interrupt(link.Reader) 233 common.Interrupt(link.Writer) 234 return newError("connection ends").Base(err) 235 } 236 237 return nil 238 } 239 240 // creates a tun interface on netstack given a configuration 241 func (h *Handler) makeVirtualTun(bind *netBindClient) (Tunnel, error) { 242 t, err := h.conf.createTun()(h.endpoints, int(h.conf.Mtu), nil) 243 if err != nil { 244 return nil, err 245 } 246 247 bind.dnsOption.IPv4Enable = h.hasIPv4 248 bind.dnsOption.IPv6Enable = h.hasIPv6 249 250 if err = t.BuildDevice(h.createIPCRequest(bind, h.conf), bind); err != nil { 251 _ = t.Close() 252 return nil, err 253 } 254 return t, nil 255 } 256 257 // serialize the config into an IPC request 258 func (h *Handler) createIPCRequest(bind *netBindClient, conf *DeviceConfig) string { 259 var request strings.Builder 260 261 request.WriteString(fmt.Sprintf("private_key=%s\n", conf.SecretKey)) 262 263 if !conf.IsClient { 264 // placeholder, we'll handle actual port listening on Xray 265 request.WriteString("listen_port=1337\n") 266 } 267 268 for _, peer := range conf.Peers { 269 if peer.PublicKey != "" { 270 request.WriteString(fmt.Sprintf("public_key=%s\n", peer.PublicKey)) 271 } 272 273 if peer.PreSharedKey != "" { 274 request.WriteString(fmt.Sprintf("preshared_key=%s\n", peer.PreSharedKey)) 275 } 276 277 address, port, err := net.SplitHostPort(peer.Endpoint) 278 if err != nil { 279 newError("failed to split endpoint ", peer.Endpoint, " into address and port").AtError().WriteToLog() 280 } 281 addr := net.ParseAddress(address) 282 if addr.Family().IsDomain() { 283 dialerIp := bind.dialer.DestIpAddress() 284 if dialerIp != nil { 285 addr = net.ParseAddress(dialerIp.String()) 286 newError("createIPCRequest use dialer dest ip: ", addr).WriteToLog() 287 } else { 288 ips, err := h.dns.LookupIP(addr.Domain(), dns.IPOption{ 289 IPv4Enable: h.hasIPv4 && h.conf.preferIP4(), 290 IPv6Enable: h.hasIPv6 && h.conf.preferIP6(), 291 }) 292 { // Resolve fallback 293 if (len(ips) == 0 || err != nil) && h.conf.hasFallback() { 294 ips, err = h.dns.LookupIP(addr.Domain(), dns.IPOption{ 295 IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(), 296 IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(), 297 }) 298 } 299 } 300 if err != nil { 301 newError("createIPCRequest failed to lookup DNS").Base(err).WriteToLog() 302 } else if len(ips) == 0 { 303 newError("createIPCRequest empty lookup DNS").WriteToLog() 304 } else { 305 addr = net.IPAddress(ips[dice.Roll(len(ips))]) 306 } 307 } 308 } 309 310 if peer.Endpoint != "" { 311 request.WriteString(fmt.Sprintf("endpoint=%s:%s\n", addr, port)) 312 } 313 314 for _, ip := range peer.AllowedIps { 315 request.WriteString(fmt.Sprintf("allowed_ip=%s\n", ip)) 316 } 317 318 if peer.KeepAlive != 0 { 319 request.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", peer.KeepAlive)) 320 } 321 } 322 323 return request.String()[:request.Len()] 324 }