github.com/metacubex/mihomo@v1.18.5/adapter/outbound/wireguard.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"net/netip"
    11  	"runtime"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  
    16  	"github.com/metacubex/mihomo/common/atomic"
    17  	CN "github.com/metacubex/mihomo/common/net"
    18  	"github.com/metacubex/mihomo/component/dialer"
    19  	"github.com/metacubex/mihomo/component/proxydialer"
    20  	"github.com/metacubex/mihomo/component/resolver"
    21  	"github.com/metacubex/mihomo/component/slowdown"
    22  	C "github.com/metacubex/mihomo/constant"
    23  	"github.com/metacubex/mihomo/dns"
    24  	"github.com/metacubex/mihomo/log"
    25  
    26  	wireguard "github.com/metacubex/sing-wireguard"
    27  
    28  	"github.com/sagernet/sing/common"
    29  	"github.com/sagernet/sing/common/debug"
    30  	E "github.com/sagernet/sing/common/exceptions"
    31  	M "github.com/sagernet/sing/common/metadata"
    32  	"github.com/sagernet/wireguard-go/device"
    33  )
    34  
    35  type WireGuard struct {
    36  	*Base
    37  	bind      *wireguard.ClientBind
    38  	device    *device.Device
    39  	tunDevice wireguard.Device
    40  	dialer    proxydialer.SingDialer
    41  	resolver  *dns.Resolver
    42  	refP      *refProxyAdapter
    43  
    44  	initOk        atomic.Bool
    45  	initMutex     sync.Mutex
    46  	initErr       error
    47  	option        WireGuardOption
    48  	connectAddr   M.Socksaddr
    49  	localPrefixes []netip.Prefix
    50  
    51  	closeCh chan struct{} // for test
    52  }
    53  
    54  type WireGuardOption struct {
    55  	BasicOption
    56  	WireGuardPeerOption
    57  	Name                string `proxy:"name"`
    58  	Ip                  string `proxy:"ip,omitempty"`
    59  	Ipv6                string `proxy:"ipv6,omitempty"`
    60  	PrivateKey          string `proxy:"private-key"`
    61  	Workers             int    `proxy:"workers,omitempty"`
    62  	MTU                 int    `proxy:"mtu,omitempty"`
    63  	UDP                 bool   `proxy:"udp,omitempty"`
    64  	PersistentKeepalive int    `proxy:"persistent-keepalive,omitempty"`
    65  
    66  	Peers []WireGuardPeerOption `proxy:"peers,omitempty"`
    67  
    68  	RemoteDnsResolve bool     `proxy:"remote-dns-resolve,omitempty"`
    69  	Dns              []string `proxy:"dns,omitempty"`
    70  }
    71  
    72  type WireGuardPeerOption struct {
    73  	Server       string   `proxy:"server"`
    74  	Port         int      `proxy:"port"`
    75  	PublicKey    string   `proxy:"public-key,omitempty"`
    76  	PreSharedKey string   `proxy:"pre-shared-key,omitempty"`
    77  	Reserved     []uint8  `proxy:"reserved,omitempty"`
    78  	AllowedIPs   []string `proxy:"allowed-ips,omitempty"`
    79  }
    80  
    81  type wgSingErrorHandler struct {
    82  	name string
    83  }
    84  
    85  var _ E.Handler = (*wgSingErrorHandler)(nil)
    86  
    87  func (w wgSingErrorHandler) NewError(ctx context.Context, err error) {
    88  	if E.IsClosedOrCanceled(err) {
    89  		log.SingLogger.Debug(fmt.Sprintf("[WG](%s) connection closed: %s", w.name, err))
    90  		return
    91  	}
    92  	log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", w.name, err))
    93  }
    94  
    95  type wgNetDialer struct {
    96  	tunDevice wireguard.Device
    97  }
    98  
    99  var _ dialer.NetDialer = (*wgNetDialer)(nil)
   100  
   101  func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   102  	return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address).Unwrap())
   103  }
   104  
   105  func (option WireGuardPeerOption) Addr() M.Socksaddr {
   106  	return M.ParseSocksaddrHostPort(option.Server, uint16(option.Port))
   107  }
   108  
   109  func (option WireGuardOption) Prefixes() ([]netip.Prefix, error) {
   110  	localPrefixes := make([]netip.Prefix, 0, 2)
   111  	if len(option.Ip) > 0 {
   112  		if !strings.Contains(option.Ip, "/") {
   113  			option.Ip = option.Ip + "/32"
   114  		}
   115  		if prefix, err := netip.ParsePrefix(option.Ip); err == nil {
   116  			localPrefixes = append(localPrefixes, prefix)
   117  		} else {
   118  			return nil, E.Cause(err, "ip address parse error")
   119  		}
   120  	}
   121  	if len(option.Ipv6) > 0 {
   122  		if !strings.Contains(option.Ipv6, "/") {
   123  			option.Ipv6 = option.Ipv6 + "/128"
   124  		}
   125  		if prefix, err := netip.ParsePrefix(option.Ipv6); err == nil {
   126  			localPrefixes = append(localPrefixes, prefix)
   127  		} else {
   128  			return nil, E.Cause(err, "ipv6 address parse error")
   129  		}
   130  	}
   131  	if len(localPrefixes) == 0 {
   132  		return nil, E.New("missing local address")
   133  	}
   134  	return localPrefixes, nil
   135  }
   136  
   137  func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
   138  	outbound := &WireGuard{
   139  		Base: &Base{
   140  			name:   option.Name,
   141  			addr:   net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
   142  			tp:     C.WireGuard,
   143  			udp:    option.UDP,
   144  			iface:  option.Interface,
   145  			rmark:  option.RoutingMark,
   146  			prefer: C.NewDNSPrefer(option.IPVersion),
   147  		},
   148  		dialer: proxydialer.NewSlowDownSingDialer(proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), slowdown.New()),
   149  	}
   150  	runtime.SetFinalizer(outbound, closeWireGuard)
   151  
   152  	var reserved [3]uint8
   153  	if len(option.Reserved) > 0 {
   154  		if len(option.Reserved) != 3 {
   155  			return nil, E.New("invalid reserved value, required 3 bytes, got ", len(option.Reserved))
   156  		}
   157  		copy(reserved[:], option.Reserved)
   158  	}
   159  	var isConnect bool
   160  	if len(option.Peers) < 2 {
   161  		isConnect = true
   162  		if len(option.Peers) == 1 {
   163  			outbound.connectAddr = option.Peers[0].Addr()
   164  		} else {
   165  			outbound.connectAddr = option.Addr()
   166  		}
   167  	}
   168  	outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, outbound.connectAddr.AddrPort(), reserved)
   169  
   170  	var err error
   171  	outbound.localPrefixes, err = option.Prefixes()
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	{
   177  		bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey)
   178  		if err != nil {
   179  			return nil, E.Cause(err, "decode private key")
   180  		}
   181  		option.PrivateKey = hex.EncodeToString(bytes)
   182  	}
   183  
   184  	if len(option.Peers) > 0 {
   185  		for i := range option.Peers {
   186  			peer := &option.Peers[i] // we need modify option here
   187  			bytes, err := base64.StdEncoding.DecodeString(peer.PublicKey)
   188  			if err != nil {
   189  				return nil, E.Cause(err, "decode public key for peer ", i)
   190  			}
   191  			peer.PublicKey = hex.EncodeToString(bytes)
   192  
   193  			if peer.PreSharedKey != "" {
   194  				bytes, err := base64.StdEncoding.DecodeString(peer.PreSharedKey)
   195  				if err != nil {
   196  					return nil, E.Cause(err, "decode pre shared key for peer ", i)
   197  				}
   198  				peer.PreSharedKey = hex.EncodeToString(bytes)
   199  			}
   200  
   201  			if len(peer.AllowedIPs) == 0 {
   202  				return nil, E.New("missing allowed_ips for peer ", i)
   203  			}
   204  
   205  			if len(peer.Reserved) > 0 {
   206  				if len(peer.Reserved) != 3 {
   207  					return nil, E.New("invalid reserved value for peer ", i, ", required 3 bytes, got ", len(peer.Reserved))
   208  				}
   209  			}
   210  		}
   211  	} else {
   212  		{
   213  			bytes, err := base64.StdEncoding.DecodeString(option.PublicKey)
   214  			if err != nil {
   215  				return nil, E.Cause(err, "decode peer public key")
   216  			}
   217  			option.PublicKey = hex.EncodeToString(bytes)
   218  		}
   219  		if option.PreSharedKey != "" {
   220  			bytes, err := base64.StdEncoding.DecodeString(option.PreSharedKey)
   221  			if err != nil {
   222  				return nil, E.Cause(err, "decode pre shared key")
   223  			}
   224  			option.PreSharedKey = hex.EncodeToString(bytes)
   225  		}
   226  	}
   227  	outbound.option = option
   228  
   229  	mtu := option.MTU
   230  	if mtu == 0 {
   231  		mtu = 1408
   232  	}
   233  	if len(outbound.localPrefixes) == 0 {
   234  		return nil, E.New("missing local address")
   235  	}
   236  	outbound.tunDevice, err = wireguard.NewStackDevice(outbound.localPrefixes, uint32(mtu))
   237  	if err != nil {
   238  		return nil, E.Cause(err, "create WireGuard device")
   239  	}
   240  	outbound.device = device.NewDevice(context.Background(), outbound.tunDevice, outbound.bind, &device.Logger{
   241  		Verbosef: func(format string, args ...interface{}) {
   242  			log.SingLogger.Debug(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
   243  		},
   244  		Errorf: func(format string, args ...interface{}) {
   245  			log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
   246  		},
   247  	}, option.Workers)
   248  
   249  	var has6 bool
   250  	for _, address := range outbound.localPrefixes {
   251  		if !address.Addr().Unmap().Is4() {
   252  			has6 = true
   253  			break
   254  		}
   255  	}
   256  
   257  	refP := &refProxyAdapter{}
   258  	outbound.refP = refP
   259  	if option.RemoteDnsResolve && len(option.Dns) > 0 {
   260  		nss, err := dns.ParseNameServer(option.Dns)
   261  		if err != nil {
   262  			return nil, err
   263  		}
   264  		for i := range nss {
   265  			nss[i].ProxyAdapter = refP
   266  		}
   267  		outbound.resolver = dns.NewResolver(dns.Config{
   268  			Main: nss,
   269  			IPv6: has6,
   270  		})
   271  	}
   272  
   273  	return outbound, nil
   274  }
   275  
   276  func (w *WireGuard) resolve(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) {
   277  	if address.Addr.IsValid() {
   278  		return address.AddrPort(), nil
   279  	}
   280  	udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), w.prefer)
   281  	if err != nil {
   282  		return netip.AddrPort{}, err
   283  	}
   284  	// net.ResolveUDPAddr maybe return 4in6 address, so unmap at here
   285  	addrPort := udpAddr.AddrPort()
   286  	return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
   287  }
   288  
   289  func (w *WireGuard) init(ctx context.Context) error {
   290  	if w.initOk.Load() {
   291  		return nil
   292  	}
   293  	w.initMutex.Lock()
   294  	defer w.initMutex.Unlock()
   295  	// double check like sync.Once
   296  	if w.initOk.Load() {
   297  		return nil
   298  	}
   299  	if w.initErr != nil {
   300  		return w.initErr
   301  	}
   302  
   303  	w.bind.ResetReservedForEndpoint()
   304  	ipcConf := "private_key=" + w.option.PrivateKey
   305  	if len(w.option.Peers) > 0 {
   306  		for i, peer := range w.option.Peers {
   307  			destination, err := w.resolve(ctx, peer.Addr())
   308  			if err != nil {
   309  				// !!! do not set initErr here !!!
   310  				// let us can retry domain resolve in next time
   311  				return E.Cause(err, "resolve endpoint domain for peer ", i)
   312  			}
   313  			ipcConf += "\npublic_key=" + peer.PublicKey
   314  			ipcConf += "\nendpoint=" + destination.String()
   315  			if peer.PreSharedKey != "" {
   316  				ipcConf += "\npreshared_key=" + peer.PreSharedKey
   317  			}
   318  			for _, allowedIP := range peer.AllowedIPs {
   319  				ipcConf += "\nallowed_ip=" + allowedIP
   320  			}
   321  			if len(peer.Reserved) > 0 {
   322  				var reserved [3]uint8
   323  				copy(reserved[:], w.option.Reserved)
   324  				w.bind.SetReservedForEndpoint(destination, reserved)
   325  			}
   326  		}
   327  	} else {
   328  		ipcConf += "\npublic_key=" + w.option.PublicKey
   329  		destination, err := w.resolve(ctx, w.connectAddr)
   330  		if err != nil {
   331  			// !!! do not set initErr here !!!
   332  			// let us can retry domain resolve in next time
   333  			return E.Cause(err, "resolve endpoint domain")
   334  		}
   335  		w.bind.SetConnectAddr(destination)
   336  		ipcConf += "\nendpoint=" + destination.String()
   337  		if w.option.PreSharedKey != "" {
   338  			ipcConf += "\npreshared_key=" + w.option.PreSharedKey
   339  		}
   340  		var has4, has6 bool
   341  		for _, address := range w.localPrefixes {
   342  			if address.Addr().Is4() {
   343  				has4 = true
   344  			} else {
   345  				has6 = true
   346  			}
   347  		}
   348  		if has4 {
   349  			ipcConf += "\nallowed_ip=0.0.0.0/0"
   350  		}
   351  		if has6 {
   352  			ipcConf += "\nallowed_ip=::/0"
   353  		}
   354  	}
   355  
   356  	if w.option.PersistentKeepalive != 0 {
   357  		ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", w.option.PersistentKeepalive)
   358  	}
   359  
   360  	if debug.Enabled {
   361  		log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", w.option.Name, ipcConf))
   362  	}
   363  	err := w.device.IpcSet(ipcConf)
   364  	if err != nil {
   365  		w.initErr = E.Cause(err, "setup wireguard")
   366  		return w.initErr
   367  	}
   368  
   369  	err = w.tunDevice.Start()
   370  	if err != nil {
   371  		w.initErr = err
   372  		return w.initErr
   373  	}
   374  
   375  	w.initOk.Store(true)
   376  	return nil
   377  }
   378  
   379  func closeWireGuard(w *WireGuard) {
   380  	if w.device != nil {
   381  		w.device.Close()
   382  	}
   383  	_ = common.Close(w.tunDevice)
   384  	if w.closeCh != nil {
   385  		close(w.closeCh)
   386  	}
   387  }
   388  
   389  func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
   390  	options := w.Base.DialOptions(opts...)
   391  	w.dialer.SetDialer(dialer.NewDialer(options...))
   392  	var conn net.Conn
   393  	if err = w.init(ctx); err != nil {
   394  		return nil, err
   395  	}
   396  	if !metadata.Resolved() || w.resolver != nil {
   397  		r := resolver.DefaultResolver
   398  		if w.resolver != nil {
   399  			w.refP.SetProxyAdapter(w)
   400  			defer w.refP.ClearProxyAdapter()
   401  			r = w.resolver
   402  		}
   403  		options = append(options, dialer.WithResolver(r))
   404  		options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
   405  		conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
   406  	} else {
   407  		conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
   408  	}
   409  	if err != nil {
   410  		return nil, err
   411  	}
   412  	if conn == nil {
   413  		return nil, E.New("conn is nil")
   414  	}
   415  	return NewConn(CN.NewRefConn(conn, w), w), nil
   416  }
   417  
   418  func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
   419  	options := w.Base.DialOptions(opts...)
   420  	w.dialer.SetDialer(dialer.NewDialer(options...))
   421  	var pc net.PacketConn
   422  	if err = w.init(ctx); err != nil {
   423  		return nil, err
   424  	}
   425  	if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
   426  		r := resolver.DefaultResolver
   427  		if w.resolver != nil {
   428  			w.refP.SetProxyAdapter(w)
   429  			defer w.refP.ClearProxyAdapter()
   430  			r = w.resolver
   431  		}
   432  		ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
   433  		if err != nil {
   434  			return nil, errors.New("can't resolve ip")
   435  		}
   436  		metadata.DstIP = ip
   437  	}
   438  	pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
   439  	if err != nil {
   440  		return nil, err
   441  	}
   442  	if pc == nil {
   443  		return nil, E.New("packetConn is nil")
   444  	}
   445  	return newPacketConn(CN.NewRefPacketConn(pc, w), w), nil
   446  }
   447  
   448  // IsL3Protocol implements C.ProxyAdapter
   449  func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool {
   450  	return true
   451  }
   452  
   453  type refProxyAdapter struct {
   454  	proxyAdapter C.ProxyAdapter
   455  	count        int
   456  	mutex        sync.Mutex
   457  }
   458  
   459  func (r *refProxyAdapter) SetProxyAdapter(proxyAdapter C.ProxyAdapter) {
   460  	r.mutex.Lock()
   461  	defer r.mutex.Unlock()
   462  	r.proxyAdapter = proxyAdapter
   463  	r.count++
   464  }
   465  
   466  func (r *refProxyAdapter) ClearProxyAdapter() {
   467  	r.mutex.Lock()
   468  	defer r.mutex.Unlock()
   469  	r.count--
   470  	if r.count == 0 {
   471  		r.proxyAdapter = nil
   472  	}
   473  }
   474  
   475  func (r *refProxyAdapter) Name() string {
   476  	if r.proxyAdapter != nil {
   477  		return r.proxyAdapter.Name()
   478  	}
   479  	return ""
   480  }
   481  
   482  func (r *refProxyAdapter) Type() C.AdapterType {
   483  	if r.proxyAdapter != nil {
   484  		return r.proxyAdapter.Type()
   485  	}
   486  	return C.AdapterType(0)
   487  }
   488  
   489  func (r *refProxyAdapter) Addr() string {
   490  	if r.proxyAdapter != nil {
   491  		return r.proxyAdapter.Addr()
   492  	}
   493  	return ""
   494  }
   495  
   496  func (r *refProxyAdapter) SupportUDP() bool {
   497  	if r.proxyAdapter != nil {
   498  		return r.proxyAdapter.SupportUDP()
   499  	}
   500  	return false
   501  }
   502  
   503  func (r *refProxyAdapter) SupportXUDP() bool {
   504  	if r.proxyAdapter != nil {
   505  		return r.proxyAdapter.SupportXUDP()
   506  	}
   507  	return false
   508  }
   509  
   510  func (r *refProxyAdapter) SupportTFO() bool {
   511  	if r.proxyAdapter != nil {
   512  		return r.proxyAdapter.SupportTFO()
   513  	}
   514  	return false
   515  }
   516  
   517  func (r *refProxyAdapter) MarshalJSON() ([]byte, error) {
   518  	if r.proxyAdapter != nil {
   519  		return r.proxyAdapter.MarshalJSON()
   520  	}
   521  	return nil, C.ErrNotSupport
   522  }
   523  
   524  func (r *refProxyAdapter) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
   525  	if r.proxyAdapter != nil {
   526  		return r.proxyAdapter.StreamConnContext(ctx, c, metadata)
   527  	}
   528  	return nil, C.ErrNotSupport
   529  }
   530  
   531  func (r *refProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
   532  	if r.proxyAdapter != nil {
   533  		return r.proxyAdapter.DialContext(ctx, metadata, opts...)
   534  	}
   535  	return nil, C.ErrNotSupport
   536  }
   537  
   538  func (r *refProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
   539  	if r.proxyAdapter != nil {
   540  		return r.proxyAdapter.ListenPacketContext(ctx, metadata, opts...)
   541  	}
   542  	return nil, C.ErrNotSupport
   543  }
   544  
   545  func (r *refProxyAdapter) SupportUOT() bool {
   546  	if r.proxyAdapter != nil {
   547  		return r.proxyAdapter.SupportUOT()
   548  	}
   549  	return false
   550  }
   551  
   552  func (r *refProxyAdapter) SupportWithDialer() C.NetWork {
   553  	if r.proxyAdapter != nil {
   554  		return r.proxyAdapter.SupportWithDialer()
   555  	}
   556  	return C.InvalidNet
   557  }
   558  
   559  func (r *refProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
   560  	if r.proxyAdapter != nil {
   561  		return r.proxyAdapter.DialContextWithDialer(ctx, dialer, metadata)
   562  	}
   563  	return nil, C.ErrNotSupport
   564  }
   565  
   566  func (r *refProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) {
   567  	if r.proxyAdapter != nil {
   568  		return r.proxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata)
   569  	}
   570  	return nil, C.ErrNotSupport
   571  }
   572  
   573  func (r *refProxyAdapter) IsL3Protocol(metadata *C.Metadata) bool {
   574  	if r.proxyAdapter != nil {
   575  		return r.proxyAdapter.IsL3Protocol(metadata)
   576  	}
   577  	return false
   578  }
   579  
   580  func (r *refProxyAdapter) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
   581  	if r.proxyAdapter != nil {
   582  		return r.proxyAdapter.Unwrap(metadata, touch)
   583  	}
   584  	return nil
   585  }
   586  
   587  var _ C.ProxyAdapter = (*refProxyAdapter)(nil)