github.com/yaling888/clash@v1.53.0/adapter/outbound/wireguard_gvsior.go (about)

     1  //go:build !nogvisor
     2  
     3  package outbound
     4  
     5  import (
     6  	"context"
     7  	"encoding/base64"
     8  	"encoding/hex"
     9  	"errors"
    10  	"fmt"
    11  	"math/rand/v2"
    12  	"net"
    13  	"net/netip"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"syscall"
    18  	"time"
    19  	_ "unsafe"
    20  
    21  	"github.com/phuslu/log"
    22  	"github.com/samber/lo"
    23  	bind "golang.zx2c4.com/wireguard/conn"
    24  	"golang.zx2c4.com/wireguard/device"
    25  	"golang.zx2c4.com/wireguard/tun"
    26  
    27  	"github.com/yaling888/clash/component/dialer"
    28  	"github.com/yaling888/clash/component/iface"
    29  	"github.com/yaling888/clash/component/resolver"
    30  	C "github.com/yaling888/clash/constant"
    31  	"github.com/yaling888/clash/transport/wireguard"
    32  )
    33  
// controlFns is the unexported controlFns slice of wireguard-go's conn
// package, reached through go:linkname (which is why this file blank-imports
// "unsafe"). getBindControlFns copies these functions in front of the
// interface-binding control function it appends.
//
//go:linkname controlFns golang.zx2c4.com/wireguard/conn.controlFns
var controlFns []func(network, address string, c syscall.RawConn) error

// dialTimeout bounds DialContext when the caller's context has no deadline.
const dialTimeout = 10 * time.Second

// Compile-time check that *WireGuard satisfies C.ProxyAdapter.
var _ C.ProxyAdapter = (*WireGuard)(nil)
    40  
// WireGuard is an outbound adapter that tunnels TCP and UDP through a
// userspace WireGuard device and netstack. Construction (NewWireGuard) only
// validates and stores configuration; the device itself is created lazily by
// init on the first DialContext/ListenPacketContext call.
type WireGuard struct {
	// Base holds the common adapter fields (name, server addr, type, udp,
	// iface, routing mark) filled in by NewWireGuard.
	*Base
	// wgDevice is the WireGuard device created by init; closed by Cleanup.
	wgDevice  *device.Device
	// tunDevice is the TUN side returned by wireguard.CreateNetTUN.
	tunDevice tun.Device
	// netStack is the userspace network stack used to dial TCP and listen UDP.
	netStack  *wireguard.Net
	// bind is the UDP bind handed to the device; refreshed by UpdateBind.
	bind      bind.Bind

	// localIP / localIPv6 are the tunnel addresses; NewWireGuard requires at
	// least one of them to be valid.
	localIP    netip.Addr
	localIPv6  netip.Addr
	// dnsServers, reserved and uapiConf are one-shot inputs consumed by init
	// and set to nil once init succeeds.
	dnsServers []netip.Addr
	reserved   []byte
	uapiConf   []string
	// threadId is the proxy name plus a random suffix, used in logs and errors.
	threadId   string
	// mtu is the TUN MTU (NewWireGuard defaults it to 1408).
	mtu        int
	// hasV6 is set by init when localIPv6 is valid; read by resolveDNS.
	hasV6      bool

	// upOnce guards init, whose result is cached in upErr; downOnce guards
	// Cleanup.
	upOnce   sync.Once
	downOnce sync.Once
	upErr    error

	// remoteDnsResolve makes resolveDNS look hostnames up through this proxy.
	remoteDnsResolve bool
}
    63  
    64  // DialContext implements C.ProxyAdapter
    65  func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.Conn, error) {
    66  	w.up()
    67  	if w.upErr != nil {
    68  		return nil, fmt.Errorf("apply wireguard proxy %s config error: %w", w.threadId, w.upErr)
    69  	}
    70  
    71  	dialCtx := ctx
    72  	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
    73  		var cancel context.CancelFunc
    74  		dialCtx, cancel = context.WithDeadline(ctx, time.Now().Add(dialTimeout))
    75  		defer cancel()
    76  	}
    77  
    78  	if err := w.resolveDNS(metadata, false); err != nil {
    79  		return nil, fmt.Errorf("resolve DNS failed: %w", err)
    80  	}
    81  
    82  	c, err := w.netStack.DialContextTCPAddrPort(dialCtx, netip.AddrPortFrom(metadata.DstIP, uint16(metadata.DstPort)))
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	if c == nil {
    87  		return nil, errors.New("conn is nil")
    88  	}
    89  	return NewConn(&wgConn{c}, w), nil
    90  }
    91  
    92  // ListenPacketContext implements C.ProxyAdapter
    93  func (w *WireGuard) ListenPacketContext(_ context.Context, metadata *C.Metadata, _ ...dialer.Option) (C.PacketConn, error) {
    94  	w.up()
    95  	if w.upErr != nil {
    96  		return nil, fmt.Errorf("apply wireguard proxy %s config failure, cause: %w", w.threadId, w.upErr)
    97  	}
    98  
    99  	if err := w.resolveDNS(metadata, true); err != nil {
   100  		return nil, fmt.Errorf("resolve DNS failed: %w", err)
   101  	}
   102  
   103  	var lAddr netip.Addr
   104  	if metadata.DstIP.Is6() {
   105  		lAddr = w.localIPv6
   106  	} else {
   107  		lAddr = w.localIP
   108  	}
   109  
   110  	pc, err := w.netStack.ListenUDPAddrPort(netip.AddrPortFrom(lAddr, 0))
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  	if pc == nil {
   115  		return nil, errors.New("packetConn is nil")
   116  	}
   117  	return NewPacketConn(&wgPConn{pc}, w), nil
   118  }
   119  
   120  // Cleanup implements C.Cleanup
   121  func (w *WireGuard) Cleanup() {
   122  	w.downOnce.Do(func() {
   123  		if w.wgDevice != nil {
   124  			w.wgDevice.Close()
   125  		}
   126  	})
   127  }
   128  
   129  // DisableDnsResolve implements C.DisableDnsResolve
   130  func (w *WireGuard) DisableDnsResolve() bool {
   131  	return true // let WireGuard resolve it
   132  }
   133  
   134  func (w *WireGuard) UpdateBind() {
   135  	if w.bind == nil || w.wgDevice == nil {
   136  		return
   137  	}
   138  	if s, ok := w.bind.(*wireguard.StdNetBind); ok {
   139  		s.UpdateControlFns(getBindControlFns(w.Base.name))
   140  	}
   141  
   142  	_ = w.wgDevice.BindUpdate()
   143  	_ = w.bindSocketToInterface()
   144  }
   145  
   146  // bindSocketToInterface used by WinRingBind
   147  func (w *WireGuard) bindSocketToInterface() error {
   148  	if b, ok := w.bind.(bind.BindSocketToInterface); ok {
   149  		interfaceName := getInterfaceName(w.Base.iface)
   150  		if interfaceName == "" {
   151  			return nil
   152  		}
   153  		obj, err := iface.ResolveInterface(interfaceName)
   154  		if err != nil {
   155  			return err
   156  		}
   157  		_ = b.BindSocketToInterface4(uint32(obj.Index), false)
   158  		_ = b.BindSocketToInterface6(uint32(obj.Index), false)
   159  	}
   160  	return nil
   161  }
   162  
   163  func (w *WireGuard) resolveDNS(metadata *C.Metadata, udp bool) error {
   164  	if metadata.Host == "" {
   165  		return nil
   166  	}
   167  	if w.remoteDnsResolve {
   168  		var (
   169  			rAddrs []netip.Addr
   170  			err    error
   171  		)
   172  		if w.hasV6 {
   173  			rAddrs, err = resolver.LookupIPByProxy(context.Background(), metadata.Host, w.name)
   174  		} else {
   175  			rAddrs, err = resolver.LookupIPv4ByProxy(context.Background(), metadata.Host, w.name)
   176  		}
   177  		if err != nil {
   178  			return err
   179  		}
   180  		if udp {
   181  			metadata.DstIP = rAddrs[0]
   182  		} else {
   183  			if w.hasV6 {
   184  				v6 := lo.Filter(rAddrs, func(addr netip.Addr, _ int) bool {
   185  					return addr.Is6()
   186  				})
   187  				if len(v6) > 0 {
   188  					rAddrs = v6
   189  				}
   190  			}
   191  			metadata.DstIP = rAddrs[rand.IntN(len(rAddrs))]
   192  		}
   193  	} else if !metadata.Resolved() {
   194  		var (
   195  			rAddrs []netip.Addr
   196  			err    error
   197  		)
   198  		if w.hasV6 {
   199  			rAddrs, err = resolver.LookupIP(context.Background(), metadata.Host)
   200  		} else {
   201  			rAddrs, err = resolver.LookupIPv4(context.Background(), metadata.Host)
   202  		}
   203  		if err != nil {
   204  			return err
   205  		}
   206  		if udp {
   207  			metadata.DstIP = rAddrs[0]
   208  		} else {
   209  			metadata.DstIP = rAddrs[rand.IntN(len(rAddrs))]
   210  		}
   211  	}
   212  	return nil
   213  }
   214  
   215  func (w *WireGuard) up() {
   216  	w.upOnce.Do(func() {
   217  		w.upErr = w.init()
   218  	})
   219  }
   220  
   221  func (w *WireGuard) init() error {
   222  	host, port, _ := net.SplitHostPort(w.Base.Addr())
   223  	tryTimes := 0
   224  
   225  lookup:
   226  	endpointIP, err := resolver.ResolveProxyServerHost(host)
   227  	if err != nil {
   228  		if tryTimes < 5 {
   229  			tryTimes++
   230  			time.Sleep(2 * time.Second)
   231  			goto lookup
   232  		}
   233  		return fmt.Errorf("parse server endpoint [%s] failure, cause: %w", w.Base.Addr(), err)
   234  	}
   235  
   236  	p, _ := strconv.ParseUint(port, 10, 16)
   237  	endpoint := netip.AddrPortFrom(endpointIP, uint16(p))
   238  	w.uapiConf = append(w.uapiConf, fmt.Sprintf("endpoint=%s", endpoint))
   239  
   240  	localIPs := make([]netip.Addr, 0, 2)
   241  	if w.localIP.IsValid() {
   242  		localIPs = append(localIPs, w.localIP)
   243  	}
   244  	if w.localIPv6.IsValid() {
   245  		w.hasV6 = true
   246  		localIPs = append(localIPs, w.localIPv6)
   247  	}
   248  
   249  	tunDevice, netStack, err := wireguard.CreateNetTUN(localIPs, w.dnsServers, w.mtu)
   250  	if err != nil {
   251  		return err
   252  	}
   253  
   254  	wgBind := wireguard.NewDefaultBind(getBindControlFns(w.Base.iface), w.Base.iface, w.reserved)
   255  	w.bind = wgBind
   256  
   257  	logger := &device.Logger{
   258  		Verbosef: func(format string, args ...any) {
   259  			log.Debug().Msgf("[WireGuard] [%s] "+strings.ToLower(format), append([]any{w.threadId}, args...)...)
   260  		},
   261  		Errorf: func(format string, args ...any) {
   262  			log.Error().Msgf("[WireGuard] [%s] "+strings.ToLower(format), append([]any{w.threadId}, args...)...)
   263  		},
   264  	}
   265  
   266  	wgDevice := device.NewDevice(tunDevice, wgBind, logger)
   267  
   268  	log.Debug().Strs("config", w.uapiConf).Msgf("[WireGuard] initial wireguard proxy %s", w.threadId)
   269  
   270  	err = wgDevice.IpcSet(strings.Join(w.uapiConf, "\n"))
   271  	if err != nil {
   272  		return err
   273  	}
   274  
   275  	_ = w.bindSocketToInterface()
   276  
   277  	w.tunDevice = tunDevice
   278  	w.netStack = netStack
   279  	w.wgDevice = wgDevice
   280  	w.uapiConf = nil
   281  	w.dnsServers = nil
   282  	w.reserved = nil
   283  	return nil
   284  }
   285  
   286  func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
   287  	uapiConf := make([]string, 0, 6)
   288  	privateKeyBytes, err := base64.StdEncoding.DecodeString(option.PrivateKey)
   289  	if err != nil {
   290  		return nil, fmt.Errorf("decode wireguard private key failure, cause: %w", err)
   291  	}
   292  	uapiConf = append(uapiConf, fmt.Sprintf("private_key=%s", hex.EncodeToString(privateKeyBytes)))
   293  
   294  	publicKeyBytes, err := base64.StdEncoding.DecodeString(option.PublicKey)
   295  	if err != nil {
   296  		return nil, fmt.Errorf("decode wireguard peer public key failure, cause: %w", err)
   297  	}
   298  	uapiConf = append(uapiConf, fmt.Sprintf("public_key=%s", hex.EncodeToString(publicKeyBytes)))
   299  
   300  	if option.PresharedKey != "" {
   301  		bytes, err := base64.StdEncoding.DecodeString(option.PresharedKey)
   302  		if err != nil {
   303  			return nil, fmt.Errorf("decode wireguard preshared key failure, cause: %w", err)
   304  		}
   305  		uapiConf = append(uapiConf, fmt.Sprintf("preshared_key=%s", hex.EncodeToString(bytes)))
   306  	}
   307  
   308  	var reservedBytes []byte
   309  	if option.Reserved != "" {
   310  		reserved := strings.TrimPrefix(strings.ToLower(option.Reserved), "0x")
   311  		if reservedBytes, err = hex.DecodeString(reserved); err != nil || len(reservedBytes) != 3 {
   312  			return nil, fmt.Errorf("decode wireguard reserved 3 bytes failure %w", err)
   313  		}
   314  	}
   315  
   316  	var (
   317  		localIP   netip.Addr
   318  		localIPv6 netip.Addr
   319  	)
   320  	if option.IP != "" {
   321  		option.IP, _, _ = strings.Cut(option.IP, "/")
   322  		if localIP, err = netip.ParseAddr(option.IP); err != nil {
   323  			return nil, fmt.Errorf("parse wireguard ip address failure, cause: %w", err)
   324  		}
   325  	}
   326  
   327  	if option.IPv6 != "" {
   328  		option.IPv6, _, _ = strings.Cut(option.IPv6, "/")
   329  		if localIPv6, err = netip.ParseAddr(option.IPv6); err != nil {
   330  			return nil, fmt.Errorf("parse wireguard ipv6 address failure, cause: %w", err)
   331  		}
   332  	}
   333  
   334  	if !localIP.IsValid() && !localIPv6.IsValid() {
   335  		return nil, errors.New("wireguard missing local ip")
   336  	}
   337  
   338  	dns := option.DNS
   339  	if len(dns) == 0 {
   340  		dns = append(dns, "1.1.1.1", "8.8.8.8")
   341  	}
   342  	dnsServers := make([]netip.Addr, len(dns))
   343  	for _, d := range dns {
   344  		if ip, err1 := netip.ParseAddr(d); err1 != nil {
   345  			return nil, fmt.Errorf("parse wireguard dns address failure, cause: %w", err1)
   346  		} else {
   347  			dnsServers = append(dnsServers, ip)
   348  		}
   349  	}
   350  
   351  	if localIP.IsValid() {
   352  		uapiConf = append(uapiConf, "allowed_ip=0.0.0.0/0")
   353  	}
   354  	if localIPv6.IsValid() {
   355  		uapiConf = append(uapiConf, "allowed_ip=::/0")
   356  	}
   357  
   358  	mtu := option.MTU
   359  	if mtu == 0 {
   360  		mtu = 1408
   361  	}
   362  
   363  	threadId := fmt.Sprintf("%s-%d", option.Name, rand.IntN(100))
   364  
   365  	base := &Base{
   366  		name:  option.Name,
   367  		addr:  net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
   368  		tp:    C.WireGuard,
   369  		udp:   option.UDP,
   370  		iface: option.Interface,
   371  		rmark: option.RoutingMark,
   372  	}
   373  	wireGuard := &WireGuard{
   374  		Base:       base,
   375  		localIP:    localIP,
   376  		localIPv6:  localIPv6,
   377  		dnsServers: dnsServers,
   378  		reserved:   reservedBytes,
   379  		uapiConf:   uapiConf,
   380  		threadId:   threadId,
   381  		mtu:        mtu,
   382  
   383  		remoteDnsResolve: option.RemoteDnsResolve,
   384  	}
   385  	return wireGuard, nil
   386  }
   387  
   388  // getBindControlFns used by StdNetBind
   389  func getBindControlFns(interfaceName string) []func(network, address string, c syscall.RawConn) error {
   390  	var bindFns []func(network, address string, c syscall.RawConn) error
   391  
   392  	bindFns = append(bindFns, controlFns...)
   393  	bindFns = append(bindFns, dialer.WithBindToInterfaceControlFn(getInterfaceName(interfaceName)))
   394  
   395  	return bindFns
   396  }
   397  
   398  func getInterfaceName(interfaceName string) string {
   399  	if interfaceName == "" {
   400  		interfaceName = dialer.DefaultInterface.Load()
   401  	}
   402  	return interfaceName
   403  }
   404  
   405  type wgConn struct {
   406  	net.Conn
   407  }
   408  
   409  type wgPConn struct {
   410  	net.PacketConn
   411  }