github.com/metacubex/mihomo@v1.18.5/adapter/outbound/wireguard.go (about) 1 package outbound 2 3 import ( 4 "context" 5 "encoding/base64" 6 "encoding/hex" 7 "errors" 8 "fmt" 9 "net" 10 "net/netip" 11 "runtime" 12 "strconv" 13 "strings" 14 "sync" 15 16 "github.com/metacubex/mihomo/common/atomic" 17 CN "github.com/metacubex/mihomo/common/net" 18 "github.com/metacubex/mihomo/component/dialer" 19 "github.com/metacubex/mihomo/component/proxydialer" 20 "github.com/metacubex/mihomo/component/resolver" 21 "github.com/metacubex/mihomo/component/slowdown" 22 C "github.com/metacubex/mihomo/constant" 23 "github.com/metacubex/mihomo/dns" 24 "github.com/metacubex/mihomo/log" 25 26 wireguard "github.com/metacubex/sing-wireguard" 27 28 "github.com/sagernet/sing/common" 29 "github.com/sagernet/sing/common/debug" 30 E "github.com/sagernet/sing/common/exceptions" 31 M "github.com/sagernet/sing/common/metadata" 32 "github.com/sagernet/wireguard-go/device" 33 ) 34 35 type WireGuard struct { 36 *Base 37 bind *wireguard.ClientBind 38 device *device.Device 39 tunDevice wireguard.Device 40 dialer proxydialer.SingDialer 41 resolver *dns.Resolver 42 refP *refProxyAdapter 43 44 initOk atomic.Bool 45 initMutex sync.Mutex 46 initErr error 47 option WireGuardOption 48 connectAddr M.Socksaddr 49 localPrefixes []netip.Prefix 50 51 closeCh chan struct{} // for test 52 } 53 54 type WireGuardOption struct { 55 BasicOption 56 WireGuardPeerOption 57 Name string `proxy:"name"` 58 Ip string `proxy:"ip,omitempty"` 59 Ipv6 string `proxy:"ipv6,omitempty"` 60 PrivateKey string `proxy:"private-key"` 61 Workers int `proxy:"workers,omitempty"` 62 MTU int `proxy:"mtu,omitempty"` 63 UDP bool `proxy:"udp,omitempty"` 64 PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"` 65 66 Peers []WireGuardPeerOption `proxy:"peers,omitempty"` 67 68 RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"` 69 Dns []string `proxy:"dns,omitempty"` 70 } 71 72 type WireGuardPeerOption struct { 73 Server string `proxy:"server"` 74 Port int `proxy:"port"` 75 PublicKey string `proxy:"public-key,omitempty"` 76 PreSharedKey string `proxy:"pre-shared-key,omitempty"` 77 Reserved []uint8 `proxy:"reserved,omitempty"` 78 AllowedIPs []string `proxy:"allowed-ips,omitempty"` 79 } 80 81 type wgSingErrorHandler struct { 82 name string 83 } 84 85 var _ E.Handler = (*wgSingErrorHandler)(nil) 86 87 func (w wgSingErrorHandler) NewError(ctx context.Context, err error) { 88 if E.IsClosedOrCanceled(err) { 89 log.SingLogger.Debug(fmt.Sprintf("[WG](%s) connection closed: %s", w.name, err)) 90 return 91 } 92 log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", w.name, err)) 93 } 94 95 type wgNetDialer struct { 96 tunDevice wireguard.Device 97 } 98 99 var _ dialer.NetDialer = (*wgNetDialer)(nil) 100 101 func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 102 return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address).Unwrap()) 103 } 104 105 func (option WireGuardPeerOption) Addr() M.Socksaddr { 106 return M.ParseSocksaddrHostPort(option.Server, uint16(option.Port)) 107 } 108 109 func (option WireGuardOption) Prefixes() ([]netip.Prefix, error) { 110 localPrefixes := make([]netip.Prefix, 0, 2) 111 if len(option.Ip) > 0 { 112 if !strings.Contains(option.Ip, "/") { 113 option.Ip = option.Ip + "/32" 114 } 115 if prefix, err := netip.ParsePrefix(option.Ip); err == nil { 116 localPrefixes = append(localPrefixes, prefix) 117 } else { 118 return nil, E.Cause(err, "ip address parse error") 119 } 120 } 121 if len(option.Ipv6) > 0 { 122 if !strings.Contains(option.Ipv6, "/") { 123 option.Ipv6 = option.Ipv6 + "/128" 124 } 125 if prefix, err := netip.ParsePrefix(option.Ipv6); err == nil { 126 localPrefixes = append(localPrefixes, prefix) 127 } else { 128 return nil, E.Cause(err, "ipv6 address parse error") 129 } 130 } 131 if len(localPrefixes) == 0 { 132 return nil, E.New("missing local address") 133 } 134 return localPrefixes, nil 135 } 136 137 func NewWireGuard(option WireGuardOption) (*WireGuard, error) { 138 outbound := &WireGuard{ 139 Base: &Base{ 140 name: option.Name, 141 addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), 142 tp: C.WireGuard, 143 udp: option.UDP, 144 iface: option.Interface, 145 rmark: option.RoutingMark, 146 prefer: C.NewDNSPrefer(option.IPVersion), 147 }, 148 dialer: proxydialer.NewSlowDownSingDialer(proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), slowdown.New()), 149 } 150 runtime.SetFinalizer(outbound, closeWireGuard) 151 152 var reserved [3]uint8 153 if len(option.Reserved) > 0 { 154 if len(option.Reserved) != 3 { 155 return nil, E.New("invalid reserved value, required 3 bytes, got ", len(option.Reserved)) 156 } 157 copy(reserved[:], option.Reserved) 158 } 159 var isConnect bool 160 if len(option.Peers) < 2 { 161 isConnect = true 162 if len(option.Peers) == 1 { 163 outbound.connectAddr = option.Peers[0].Addr() 164 } else { 165 outbound.connectAddr = option.Addr() 166 } 167 } 168 outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, outbound.connectAddr.AddrPort(), reserved) 169 170 var err error 171 outbound.localPrefixes, err = option.Prefixes() 172 if err != nil { 173 return nil, err 174 } 175 176 { 177 bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey) 178 if err != nil { 179 return nil, E.Cause(err, "decode private key") 180 } 181 option.PrivateKey = hex.EncodeToString(bytes) 182 } 183 184 if len(option.Peers) > 0 { 185 for i := range option.Peers { 186 peer := &option.Peers[i] // we need modify option here 187 bytes, err := base64.StdEncoding.DecodeString(peer.PublicKey) 188 if err != nil { 189 return nil, E.Cause(err, "decode public key for peer ", i) 190 } 191 peer.PublicKey = hex.EncodeToString(bytes) 192 193 if peer.PreSharedKey != "" { 194 bytes, err := base64.StdEncoding.DecodeString(peer.PreSharedKey) 195 if err != nil { 196 return nil, E.Cause(err, "decode pre shared key for peer ", i) 197 } 198 peer.PreSharedKey = hex.EncodeToString(bytes) 199 } 200 201 if len(peer.AllowedIPs) == 0 { 202 return nil, E.New("missing allowed_ips for peer ", i) 203 } 204 205 if len(peer.Reserved) > 0 { 206 if len(peer.Reserved) != 3 { 207 return nil, E.New("invalid reserved value for peer ", i, ", required 3 bytes, got ", len(peer.Reserved)) 208 } 209 } 210 } 211 } else { 212 { 213 bytes, err := base64.StdEncoding.DecodeString(option.PublicKey) 214 if err != nil { 215 return nil, E.Cause(err, "decode peer public key") 216 } 217 option.PublicKey = hex.EncodeToString(bytes) 218 } 219 if option.PreSharedKey != "" { 220 bytes, err := base64.StdEncoding.DecodeString(option.PreSharedKey) 221 if err != nil { 222 return nil, E.Cause(err, "decode pre shared key") 223 } 224 option.PreSharedKey = hex.EncodeToString(bytes) 225 } 226 } 227 outbound.option = option 228 229 mtu := option.MTU 230 if mtu == 0 { 231 mtu = 1408 232 } 233 if len(outbound.localPrefixes) == 0 { 234 return nil, E.New("missing local address") 235 } 236 outbound.tunDevice, err = wireguard.NewStackDevice(outbound.localPrefixes, uint32(mtu)) 237 if err != nil { 238 return nil, E.Cause(err, "create WireGuard device") 239 } 240 outbound.device = device.NewDevice(context.Background(), outbound.tunDevice, outbound.bind, &device.Logger{ 241 Verbosef: func(format string, args ...interface{}) { 242 log.SingLogger.Debug(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...))) 243 }, 244 Errorf: func(format string, args ...interface{}) { 245 log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...))) 246 }, 247 }, option.Workers) 248 249 var has6 bool 250 for _, address := range outbound.localPrefixes { 251 if !address.Addr().Unmap().Is4() { 252 has6 = true 253 break 254 } 255 } 256 257 refP := &refProxyAdapter{} 258 outbound.refP = refP 259 if option.RemoteDnsResolve && len(option.Dns) > 0 { 260 nss, err := dns.ParseNameServer(option.Dns) 261 if err != nil { 262 return nil, err 263 } 264 for i := range nss { 265 nss[i].ProxyAdapter = refP 266 } 267 outbound.resolver = dns.NewResolver(dns.Config{ 268 Main: nss, 269 IPv6: has6, 270 }) 271 } 272 273 return outbound, nil 274 } 275 276 func (w *WireGuard) resolve(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) { 277 if address.Addr.IsValid() { 278 return address.AddrPort(), nil 279 } 280 udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), w.prefer) 281 if err != nil { 282 return netip.AddrPort{}, err 283 } 284 // net.ResolveUDPAddr maybe return 4in6 address, so unmap at here 285 addrPort := udpAddr.AddrPort() 286 return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil 287 } 288 289 func (w *WireGuard) init(ctx context.Context) error { 290 if w.initOk.Load() { 291 return nil 292 } 293 w.initMutex.Lock() 294 defer w.initMutex.Unlock() 295 // double check like sync.Once 296 if w.initOk.Load() { 297 return nil 298 } 299 if w.initErr != nil { 300 return w.initErr 301 } 302 303 w.bind.ResetReservedForEndpoint() 304 ipcConf := "private_key=" + w.option.PrivateKey 305 if len(w.option.Peers) > 0 { 306 for i, peer := range w.option.Peers { 307 destination, err := w.resolve(ctx, peer.Addr()) 308 if err != nil { 309 // !!! do not set initErr here !!! 310 // let us can retry domain resolve in next time 311 return E.Cause(err, "resolve endpoint domain for peer ", i) 312 } 313 ipcConf += "\npublic_key=" + peer.PublicKey 314 ipcConf += "\nendpoint=" + destination.String() 315 if peer.PreSharedKey != "" { 316 ipcConf += "\npreshared_key=" + peer.PreSharedKey 317 } 318 for _, allowedIP := range peer.AllowedIPs { 319 ipcConf += "\nallowed_ip=" + allowedIP 320 } 321 if len(peer.Reserved) > 0 { 322 var reserved [3]uint8 323 copy(reserved[:], w.option.Reserved) 324 w.bind.SetReservedForEndpoint(destination, reserved) 325 } 326 } 327 } else { 328 ipcConf += "\npublic_key=" + w.option.PublicKey 329 destination, err := w.resolve(ctx, w.connectAddr) 330 if err != nil { 331 // !!! do not set initErr here !!! 332 // let us can retry domain resolve in next time 333 return E.Cause(err, "resolve endpoint domain") 334 } 335 w.bind.SetConnectAddr(destination) 336 ipcConf += "\nendpoint=" + destination.String() 337 if w.option.PreSharedKey != "" { 338 ipcConf += "\npreshared_key=" + w.option.PreSharedKey 339 } 340 var has4, has6 bool 341 for _, address := range w.localPrefixes { 342 if address.Addr().Is4() { 343 has4 = true 344 } else { 345 has6 = true 346 } 347 } 348 if has4 { 349 ipcConf += "\nallowed_ip=0.0.0.0/0" 350 } 351 if has6 { 352 ipcConf += "\nallowed_ip=::/0" 353 } 354 } 355 356 if w.option.PersistentKeepalive != 0 { 357 ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", w.option.PersistentKeepalive) 358 } 359 360 if debug.Enabled { 361 log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", w.option.Name, ipcConf)) 362 } 363 err := w.device.IpcSet(ipcConf) 364 if err != nil { 365 w.initErr = E.Cause(err, "setup wireguard") 366 return w.initErr 367 } 368 369 err = w.tunDevice.Start() 370 if err != nil { 371 w.initErr = err 372 return w.initErr 373 } 374 375 w.initOk.Store(true) 376 return nil 377 } 378 379 func closeWireGuard(w *WireGuard) { 380 if w.device != nil { 381 w.device.Close() 382 } 383 _ = common.Close(w.tunDevice) 384 if w.closeCh != nil { 385 close(w.closeCh) 386 } 387 } 388 389 func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { 390 options := w.Base.DialOptions(opts...) 391 w.dialer.SetDialer(dialer.NewDialer(options...)) 392 var conn net.Conn 393 if err = w.init(ctx); err != nil { 394 return nil, err 395 } 396 if !metadata.Resolved() || w.resolver != nil { 397 r := resolver.DefaultResolver 398 if w.resolver != nil { 399 w.refP.SetProxyAdapter(w) 400 defer w.refP.ClearProxyAdapter() 401 r = w.resolver 402 } 403 options = append(options, dialer.WithResolver(r)) 404 options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice})) 405 conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress()) 406 } else { 407 conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap()) 408 } 409 if err != nil { 410 return nil, err 411 } 412 if conn == nil { 413 return nil, E.New("conn is nil") 414 } 415 return NewConn(CN.NewRefConn(conn, w), w), nil 416 } 417 418 func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) { 419 options := w.Base.DialOptions(opts...) 420 w.dialer.SetDialer(dialer.NewDialer(options...)) 421 var pc net.PacketConn 422 if err = w.init(ctx); err != nil { 423 return nil, err 424 } 425 if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" { 426 r := resolver.DefaultResolver 427 if w.resolver != nil { 428 w.refP.SetProxyAdapter(w) 429 defer w.refP.ClearProxyAdapter() 430 r = w.resolver 431 } 432 ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) 433 if err != nil { 434 return nil, errors.New("can't resolve ip") 435 } 436 metadata.DstIP = ip 437 } 438 pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap()) 439 if err != nil { 440 return nil, err 441 } 442 if pc == nil { 443 return nil, E.New("packetConn is nil") 444 } 445 return newPacketConn(CN.NewRefPacketConn(pc, w), w), nil 446 } 447 448 // IsL3Protocol implements C.ProxyAdapter 449 func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool { 450 return true 451 } 452 453 type refProxyAdapter struct { 454 proxyAdapter C.ProxyAdapter 455 count int 456 mutex sync.Mutex 457 } 458 459 func (r *refProxyAdapter) SetProxyAdapter(proxyAdapter C.ProxyAdapter) { 460 r.mutex.Lock() 461 defer r.mutex.Unlock() 462 r.proxyAdapter = proxyAdapter 463 r.count++ 464 } 465 466 func (r *refProxyAdapter) ClearProxyAdapter() { 467 r.mutex.Lock() 468 defer r.mutex.Unlock() 469 r.count-- 470 if r.count == 0 { 471 r.proxyAdapter = nil 472 } 473 } 474 475 func (r *refProxyAdapter) Name() string { 476 if r.proxyAdapter != nil { 477 return r.proxyAdapter.Name() 478 } 479 return "" 480 } 481 482 func (r *refProxyAdapter) Type() C.AdapterType { 483 if r.proxyAdapter != nil { 484 return r.proxyAdapter.Type() 485 } 486 return C.AdapterType(0) 487 } 488 489 func (r *refProxyAdapter) Addr() string { 490 if r.proxyAdapter != nil { 491 return r.proxyAdapter.Addr() 492 } 493 return "" 494 } 495 496 func (r *refProxyAdapter) SupportUDP() bool { 497 if r.proxyAdapter != nil { 498 return r.proxyAdapter.SupportUDP() 499 } 500 return false 501 } 502 503 func (r *refProxyAdapter) SupportXUDP() bool { 504 if r.proxyAdapter != nil { 505 return r.proxyAdapter.SupportXUDP() 506 } 507 return false 508 } 509 510 func (r *refProxyAdapter) SupportTFO() bool { 511 if r.proxyAdapter != nil { 512 return r.proxyAdapter.SupportTFO() 513 } 514 return false 515 } 516 517 func (r *refProxyAdapter) MarshalJSON() ([]byte, error) { 518 if r.proxyAdapter != nil { 519 return r.proxyAdapter.MarshalJSON() 520 } 521 return nil, C.ErrNotSupport 522 } 523 524 func (r *refProxyAdapter) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { 525 if r.proxyAdapter != nil { 526 return r.proxyAdapter.StreamConnContext(ctx, c, metadata) 527 } 528 return nil, C.ErrNotSupport 529 } 530 531 func (r *refProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { 532 if r.proxyAdapter != nil { 533 return r.proxyAdapter.DialContext(ctx, metadata, opts...) 534 } 535 return nil, C.ErrNotSupport 536 } 537 538 func (r *refProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { 539 if r.proxyAdapter != nil { 540 return r.proxyAdapter.ListenPacketContext(ctx, metadata, opts...) 541 } 542 return nil, C.ErrNotSupport 543 } 544 545 func (r *refProxyAdapter) SupportUOT() bool { 546 if r.proxyAdapter != nil { 547 return r.proxyAdapter.SupportUOT() 548 } 549 return false 550 } 551 552 func (r *refProxyAdapter) SupportWithDialer() C.NetWork { 553 if r.proxyAdapter != nil { 554 return r.proxyAdapter.SupportWithDialer() 555 } 556 return C.InvalidNet 557 } 558 559 func (r *refProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) { 560 if r.proxyAdapter != nil { 561 return r.proxyAdapter.DialContextWithDialer(ctx, dialer, metadata) 562 } 563 return nil, C.ErrNotSupport 564 } 565 566 func (r *refProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) { 567 if r.proxyAdapter != nil { 568 return r.proxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata) 569 } 570 return nil, C.ErrNotSupport 571 } 572 573 func (r *refProxyAdapter) IsL3Protocol(metadata *C.Metadata) bool { 574 if r.proxyAdapter != nil { 575 return r.proxyAdapter.IsL3Protocol(metadata) 576 } 577 return false 578 } 579 580 func (r *refProxyAdapter) Unwrap(metadata *C.Metadata, touch bool) C.Proxy { 581 if r.proxyAdapter != nil { 582 return r.proxyAdapter.Unwrap(metadata, touch) 583 } 584 return nil 585 } 586 587 var _ C.ProxyAdapter = (*refProxyAdapter)(nil)