github.com/yaling888/clash@v1.53.0/adapter/outbound/wireguard_gvsior.go (about) 1 //go:build !nogvisor 2 3 package outbound 4 5 import ( 6 "context" 7 "encoding/base64" 8 "encoding/hex" 9 "errors" 10 "fmt" 11 "math/rand/v2" 12 "net" 13 "net/netip" 14 "strconv" 15 "strings" 16 "sync" 17 "syscall" 18 "time" 19 _ "unsafe" 20 21 "github.com/phuslu/log" 22 "github.com/samber/lo" 23 bind "golang.zx2c4.com/wireguard/conn" 24 "golang.zx2c4.com/wireguard/device" 25 "golang.zx2c4.com/wireguard/tun" 26 27 "github.com/yaling888/clash/component/dialer" 28 "github.com/yaling888/clash/component/iface" 29 "github.com/yaling888/clash/component/resolver" 30 C "github.com/yaling888/clash/constant" 31 "github.com/yaling888/clash/transport/wireguard" 32 ) 33 34 //go:linkname controlFns golang.zx2c4.com/wireguard/conn.controlFns 35 var controlFns []func(network, address string, c syscall.RawConn) error 36 37 const dialTimeout = 10 * time.Second 38 39 var _ C.ProxyAdapter = (*WireGuard)(nil) 40 41 type WireGuard struct { 42 *Base 43 wgDevice *device.Device 44 tunDevice tun.Device 45 netStack *wireguard.Net 46 bind bind.Bind 47 48 localIP netip.Addr 49 localIPv6 netip.Addr 50 dnsServers []netip.Addr 51 reserved []byte 52 uapiConf []string 53 threadId string 54 mtu int 55 hasV6 bool 56 57 upOnce sync.Once 58 downOnce sync.Once 59 upErr error 60 61 remoteDnsResolve bool 62 } 63 64 // DialContext implements C.ProxyAdapter 65 func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.Conn, error) { 66 w.up() 67 if w.upErr != nil { 68 return nil, fmt.Errorf("apply wireguard proxy %s config error: %w", w.threadId, w.upErr) 69 } 70 71 dialCtx := ctx 72 if _, hasDeadline := ctx.Deadline(); !hasDeadline { 73 var cancel context.CancelFunc 74 dialCtx, cancel = context.WithDeadline(ctx, time.Now().Add(dialTimeout)) 75 defer cancel() 76 } 77 78 if err := w.resolveDNS(metadata, false); err != nil { 79 return nil, fmt.Errorf("resolve DNS failed: %w", err) 80 } 81 82 c, err := w.netStack.DialContextTCPAddrPort(dialCtx, netip.AddrPortFrom(metadata.DstIP, uint16(metadata.DstPort))) 83 if err != nil { 84 return nil, err 85 } 86 if c == nil { 87 return nil, errors.New("conn is nil") 88 } 89 return NewConn(&wgConn{c}, w), nil 90 } 91 92 // ListenPacketContext implements C.ProxyAdapter 93 func (w *WireGuard) ListenPacketContext(_ context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.PacketConn, error) { 94 w.up() 95 if w.upErr != nil { 96 return nil, fmt.Errorf("apply wireguard proxy %s config failure, cause: %w", w.threadId, w.upErr) 97 } 98 99 if err := w.resolveDNS(metadata, true); err != nil { 100 return nil, fmt.Errorf("resolve DNS failed: %w", err) 101 } 102 103 var lAddr netip.Addr 104 if metadata.DstIP.Is6() { 105 lAddr = w.localIPv6 106 } else { 107 lAddr = w.localIP 108 } 109 110 pc, err := w.netStack.ListenUDPAddrPort(netip.AddrPortFrom(lAddr, 0)) 111 if err != nil { 112 return nil, err 113 } 114 if pc == nil { 115 return nil, errors.New("packetConn is nil") 116 } 117 return NewPacketConn(&wgPConn{pc}, w), nil 118 } 119 120 // Cleanup implements C.Cleanup 121 func (w *WireGuard) Cleanup() { 122 w.downOnce.Do(func() { 123 if w.wgDevice != nil { 124 w.wgDevice.Close() 125 } 126 }) 127 } 128 129 // DisableDnsResolve implements C.DisableDnsResolve 130 func (w *WireGuard) DisableDnsResolve() bool { 131 return true // let WireGuard resolve it 132 } 133 134 func (w *WireGuard) UpdateBind() { 135 if w.bind == nil || w.wgDevice == nil { 136 return 137 } 138 if s, ok := w.bind.(*wireguard.StdNetBind); ok { 139 s.UpdateControlFns(getBindControlFns(w.Base.name)) 140 } 141 142 _ = w.wgDevice.BindUpdate() 143 _ = w.bindSocketToInterface() 144 } 145 146 // bindSocketToInterface used by WinRingBind 147 func (w *WireGuard) bindSocketToInterface() error { 148 if b, ok := w.bind.(bind.BindSocketToInterface); ok { 149 interfaceName := getInterfaceName(w.Base.iface) 150 if interfaceName == "" { 151 return nil 152 } 153 obj, err := iface.ResolveInterface(interfaceName) 154 if err != nil { 155 return err 156 } 157 _ = b.BindSocketToInterface4(uint32(obj.Index), false) 158 _ = b.BindSocketToInterface6(uint32(obj.Index), false) 159 } 160 return nil 161 } 162 163 func (w *WireGuard) resolveDNS(metadata *C.Metadata, udp bool) error { 164 if metadata.Host == "" { 165 return nil 166 } 167 if w.remoteDnsResolve { 168 var ( 169 rAddrs []netip.Addr 170 err error 171 ) 172 if w.hasV6 { 173 rAddrs, err = resolver.LookupIPByProxy(context.Background(), metadata.Host, w.name) 174 } else { 175 rAddrs, err = resolver.LookupIPv4ByProxy(context.Background(), metadata.Host, w.name) 176 } 177 if err != nil { 178 return err 179 } 180 if udp { 181 metadata.DstIP = rAddrs[0] 182 } else { 183 if w.hasV6 { 184 v6 := lo.Filter(rAddrs, func(addr netip.Addr, _ int) bool { 185 return addr.Is6() 186 }) 187 if len(v6) > 0 { 188 rAddrs = v6 189 } 190 } 191 metadata.DstIP = rAddrs[rand.IntN(len(rAddrs))] 192 } 193 } else if !metadata.Resolved() { 194 var ( 195 rAddrs []netip.Addr 196 err error 197 ) 198 if w.hasV6 { 199 rAddrs, err = resolver.LookupIP(context.Background(), metadata.Host) 200 } else { 201 rAddrs, err = resolver.LookupIPv4(context.Background(), metadata.Host) 202 } 203 if err != nil { 204 return err 205 } 206 if udp { 207 metadata.DstIP = rAddrs[0] 208 } else { 209 metadata.DstIP = rAddrs[rand.IntN(len(rAddrs))] 210 } 211 } 212 return nil 213 } 214 215 func (w *WireGuard) up() { 216 w.upOnce.Do(func() { 217 w.upErr = w.init() 218 }) 219 } 220 221 func (w *WireGuard) init() error { 222 host, port, _ := net.SplitHostPort(w.Base.Addr()) 223 tryTimes := 0 224 225 lookup: 226 endpointIP, err := resolver.ResolveProxyServerHost(host) 227 if err != nil { 228 if tryTimes < 5 { 229 tryTimes++ 230 time.Sleep(2 * time.Second) 231 goto lookup 232 } 233 return fmt.Errorf("parse server endpoint [%s] failure, cause: %w", w.Base.Addr(), err) 234 } 235 236 p, _ := strconv.ParseUint(port, 10, 16) 237 endpoint := netip.AddrPortFrom(endpointIP, uint16(p)) 238 w.uapiConf = append(w.uapiConf, fmt.Sprintf("endpoint=%s", endpoint)) 239 240 localIPs := make([]netip.Addr, 0, 2) 241 if w.localIP.IsValid() { 242 localIPs = append(localIPs, w.localIP) 243 } 244 if w.localIPv6.IsValid() { 245 w.hasV6 = true 246 localIPs = append(localIPs, w.localIPv6) 247 } 248 249 tunDevice, netStack, err := wireguard.CreateNetTUN(localIPs, w.dnsServers, w.mtu) 250 if err != nil { 251 return err 252 } 253 254 wgBind := wireguard.NewDefaultBind(getBindControlFns(w.Base.iface), w.Base.iface, w.reserved) 255 w.bind = wgBind 256 257 logger := &device.Logger{ 258 Verbosef: func(format string, args ...any) { 259 log.Debug().Msgf("[WireGuard] [%s] "+strings.ToLower(format), append([]any{w.threadId}, args...)...) 260 }, 261 Errorf: func(format string, args ...any) { 262 log.Error().Msgf("[WireGuard] [%s] "+strings.ToLower(format), append([]any{w.threadId}, args...)...) 263 }, 264 } 265 266 wgDevice := device.NewDevice(tunDevice, wgBind, logger) 267 268 log.Debug().Strs("config", w.uapiConf).Msgf("[WireGuard] initial wireguard proxy %s", w.threadId) 269 270 err = wgDevice.IpcSet(strings.Join(w.uapiConf, "\n")) 271 if err != nil { 272 return err 273 } 274 275 _ = w.bindSocketToInterface() 276 277 w.tunDevice = tunDevice 278 w.netStack = netStack 279 w.wgDevice = wgDevice 280 w.uapiConf = nil 281 w.dnsServers = nil 282 w.reserved = nil 283 return nil 284 } 285 286 func NewWireGuard(option WireGuardOption) (*WireGuard, error) { 287 uapiConf := make([]string, 0, 6) 288 privateKeyBytes, err := base64.StdEncoding.DecodeString(option.PrivateKey) 289 if err != nil { 290 return nil, fmt.Errorf("decode wireguard private key failure, cause: %w", err) 291 } 292 uapiConf = append(uapiConf, fmt.Sprintf("private_key=%s", hex.EncodeToString(privateKeyBytes))) 293 294 publicKeyBytes, err := base64.StdEncoding.DecodeString(option.PublicKey) 295 if err != nil { 296 return nil, fmt.Errorf("decode wireguard peer public key failure, cause: %w", err) 297 } 298 uapiConf = append(uapiConf, fmt.Sprintf("public_key=%s", hex.EncodeToString(publicKeyBytes))) 299 300 if option.PresharedKey != "" { 301 bytes, err := base64.StdEncoding.DecodeString(option.PresharedKey) 302 if err != nil { 303 return nil, fmt.Errorf("decode wireguard preshared key failure, cause: %w", err) 304 } 305 uapiConf = append(uapiConf, fmt.Sprintf("preshared_key=%s", hex.EncodeToString(bytes))) 306 } 307 308 var reservedBytes []byte 309 if option.Reserved != "" { 310 reserved := strings.TrimPrefix(strings.ToLower(option.Reserved), "0x") 311 if reservedBytes, err = hex.DecodeString(reserved); err != nil || len(reservedBytes) != 3 { 312 return nil, fmt.Errorf("decode wireguard reserved 3 bytes failure %w", err) 313 } 314 } 315 316 var ( 317 localIP netip.Addr 318 localIPv6 netip.Addr 319 ) 320 if option.IP != "" { 321 option.IP, _, _ = strings.Cut(option.IP, "/") 322 if localIP, err = netip.ParseAddr(option.IP); err != nil { 323 return nil, fmt.Errorf("parse wireguard ip address failure, cause: %w", err) 324 } 325 } 326 327 if option.IPv6 != "" { 328 option.IPv6, _, _ = strings.Cut(option.IPv6, "/") 329 if localIPv6, err = netip.ParseAddr(option.IPv6); err != nil { 330 return nil, fmt.Errorf("parse wireguard ipv6 address failure, cause: %w", err) 331 } 332 } 333 334 if !localIP.IsValid() && !localIPv6.IsValid() { 335 return nil, errors.New("wireguard missing local ip") 336 } 337 338 dns := option.DNS 339 if len(dns) == 0 { 340 dns = append(dns, "1.1.1.1", "8.8.8.8") 341 } 342 dnsServers := make([]netip.Addr, len(dns)) 343 for _, d := range dns { 344 if ip, err1 := netip.ParseAddr(d); err1 != nil { 345 return nil, fmt.Errorf("parse wireguard dns address failure, cause: %w", err1) 346 } else { 347 dnsServers = append(dnsServers, ip) 348 } 349 } 350 351 if localIP.IsValid() { 352 uapiConf = append(uapiConf, "allowed_ip=0.0.0.0/0") 353 } 354 if localIPv6.IsValid() { 355 uapiConf = append(uapiConf, "allowed_ip=::/0") 356 } 357 358 mtu := option.MTU 359 if mtu == 0 { 360 mtu = 1408 361 } 362 363 threadId := fmt.Sprintf("%s-%d", option.Name, rand.IntN(100)) 364 365 base := &Base{ 366 name: option.Name, 367 addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), 368 tp: C.WireGuard, 369 udp: option.UDP, 370 iface: option.Interface, 371 rmark: option.RoutingMark, 372 } 373 wireGuard := &WireGuard{ 374 Base: base, 375 localIP: localIP, 376 localIPv6: localIPv6, 377 dnsServers: dnsServers, 378 reserved: reservedBytes, 379 uapiConf: uapiConf, 380 threadId: threadId, 381 mtu: mtu, 382 383 remoteDnsResolve: option.RemoteDnsResolve, 384 } 385 return wireGuard, nil 386 } 387 388 // getBindControlFns used by StdNetBind 389 func getBindControlFns(interfaceName string) []func(network, address string, c syscall.RawConn) error { 390 var bindFns []func(network, address string, c syscall.RawConn) error 391 392 bindFns = append(bindFns, controlFns...) 393 bindFns = append(bindFns, dialer.WithBindToInterfaceControlFn(getInterfaceName(interfaceName))) 394 395 return bindFns 396 } 397 398 func getInterfaceName(interfaceName string) string { 399 if interfaceName == "" { 400 interfaceName = dialer.DefaultInterface.Load() 401 } 402 return interfaceName 403 } 404 405 type wgConn struct { 406 net.Conn 407 } 408 409 type wgPConn struct { 410 net.PacketConn 411 }