github.com/metacubex/mihomo@v1.18.5/component/dialer/dialer.go (about)

     1  package dialer
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/netip"
     9  	"os"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/metacubex/mihomo/component/resolver"
    16  	"github.com/metacubex/mihomo/constant/features"
    17  	"github.com/metacubex/mihomo/log"
    18  )
    19  
    20  const (
    21  	DefaultTCPTimeout = 5 * time.Second
    22  	DefaultUDPTimeout = DefaultTCPTimeout
    23  )
    24  
    25  type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error)
    26  
    27  var (
    28  	dialMux                      sync.Mutex
    29  	IP4PEnable                   bool
    30  	actualSingleStackDialContext = serialSingleStackDialContext
    31  	actualDualStackDialContext   = serialDualStackDialContext
    32  	tcpConcurrent                = false
    33  	fallbackTimeout              = 300 * time.Millisecond
    34  )
    35  
    36  func applyOptions(options ...Option) *option {
    37  	opt := &option{
    38  		interfaceName: DefaultInterface.Load(),
    39  		routingMark:   int(DefaultRoutingMark.Load()),
    40  	}
    41  
    42  	for _, o := range DefaultOptions {
    43  		o(opt)
    44  	}
    45  
    46  	for _, o := range options {
    47  		o(opt)
    48  	}
    49  
    50  	return opt
    51  }
    52  
    53  func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
    54  	opt := applyOptions(options...)
    55  
    56  	if opt.network == 4 || opt.network == 6 {
    57  		if strings.Contains(network, "tcp") {
    58  			network = "tcp"
    59  		} else {
    60  			network = "udp"
    61  		}
    62  
    63  		network = fmt.Sprintf("%s%d", network, opt.network)
    64  	}
    65  
    66  	ips, port, err := parseAddr(ctx, network, address, opt.resolver)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	switch network {
    72  	case "tcp4", "tcp6", "udp4", "udp6":
    73  		return actualSingleStackDialContext(ctx, network, ips, port, opt)
    74  	case "tcp", "udp":
    75  		return actualDualStackDialContext(ctx, network, ips, port, opt)
    76  	default:
    77  		return nil, ErrorInvalidedNetworkStack
    78  	}
    79  }
    80  
    81  func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) {
    82  	if features.CMFA && DefaultSocketHook != nil {
    83  		return listenPacketHooked(ctx, network, address)
    84  	}
    85  
    86  	cfg := applyOptions(options...)
    87  
    88  	lc := &net.ListenConfig{}
    89  	if cfg.interfaceName != "" {
    90  		bind := bindIfaceToListenConfig
    91  		if cfg.fallbackBind {
    92  			bind = fallbackBindIfaceToListenConfig
    93  		}
    94  		addr, err := bind(cfg.interfaceName, lc, network, address, rAddrPort)
    95  		if err != nil {
    96  			return nil, err
    97  		}
    98  		address = addr
    99  	}
   100  	if cfg.addrReuse {
   101  		addrReuseToListenConfig(lc)
   102  	}
   103  	if cfg.routingMark != 0 {
   104  		bindMarkToListenConfig(cfg.routingMark, lc, network, address)
   105  	}
   106  
   107  	return lc.ListenPacket(ctx, network, address)
   108  }
   109  
   110  func SetTcpConcurrent(concurrent bool) {
   111  	dialMux.Lock()
   112  	defer dialMux.Unlock()
   113  	tcpConcurrent = concurrent
   114  	if concurrent {
   115  		actualSingleStackDialContext = concurrentSingleStackDialContext
   116  		actualDualStackDialContext = concurrentDualStackDialContext
   117  	} else {
   118  		actualSingleStackDialContext = serialSingleStackDialContext
   119  		actualDualStackDialContext = serialDualStackDialContext
   120  	}
   121  }
   122  
   123  func GetTcpConcurrent() bool {
   124  	dialMux.Lock()
   125  	defer dialMux.Unlock()
   126  	return tcpConcurrent
   127  }
   128  
   129  func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
   130  	if features.CMFA && DefaultSocketHook != nil {
   131  		return dialContextHooked(ctx, network, destination, port)
   132  	}
   133  
   134  	var address string
   135  	if IP4PEnable {
   136  		destination, port = lookupIP4P(destination, port)
   137  	}
   138  	address = net.JoinHostPort(destination.String(), port)
   139  
   140  	netDialer := opt.netDialer
   141  	switch netDialer.(type) {
   142  	case nil:
   143  		netDialer = &net.Dialer{}
   144  	case *net.Dialer:
   145  		_netDialer := *netDialer.(*net.Dialer)
   146  		netDialer = &_netDialer // make a copy
   147  	default:
   148  		return netDialer.DialContext(ctx, network, address)
   149  	}
   150  
   151  	dialer := netDialer.(*net.Dialer)
   152  	if opt.interfaceName != "" {
   153  		bind := bindIfaceToDialer
   154  		if opt.fallbackBind {
   155  			bind = fallbackBindIfaceToDialer
   156  		}
   157  		if err := bind(opt.interfaceName, dialer, network, destination); err != nil {
   158  			return nil, err
   159  		}
   160  	}
   161  	if opt.routingMark != 0 {
   162  		bindMarkToDialer(opt.routingMark, dialer, network, destination)
   163  	}
   164  	if opt.mpTcp {
   165  		setMultiPathTCP(dialer)
   166  	}
   167  	if opt.tfo && !DisableTFO {
   168  		return dialTFO(ctx, *dialer, network, address)
   169  	}
   170  	return dialer.DialContext(ctx, network, address)
   171  }
   172  
   173  func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
   174  	return serialDialContext(ctx, network, ips, port, opt)
   175  }
   176  
   177  func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
   178  	return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
   179  }
   180  
   181  func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
   182  	return parallelDialContext(ctx, network, ips, port, opt)
   183  }
   184  
   185  func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
   186  	if opt.prefer != 4 && opt.prefer != 6 {
   187  		return parallelDialContext(ctx, network, ips, port, opt)
   188  	}
   189  	return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
   190  }
   191  
   192  func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
   193  	ipv4s, ipv6s := resolver.SortationAddr(ips)
   194  	if len(ipv4s) == 0 && len(ipv6s) == 0 {
   195  		return nil, ErrorNoIpAddress
   196  	}
   197  
   198  	preferIPVersion := opt.prefer
   199  	fallbackTicker := time.NewTicker(fallbackTimeout)
   200  	defer fallbackTicker.Stop()
   201  
   202  	results := make(chan dialResult)
   203  	returned := make(chan struct{})
   204  	defer close(returned)
   205  
   206  	var wg sync.WaitGroup
   207  
   208  	racer := func(ips []netip.Addr, isPrimary bool) {
   209  		defer wg.Done()
   210  		result := dialResult{isPrimary: isPrimary}
   211  		defer func() {
   212  			select {
   213  			case results <- result:
   214  			case <-returned:
   215  				if result.Conn != nil && result.error == nil {
   216  					_ = result.Conn.Close()
   217  				}
   218  			}
   219  		}()
   220  		result.Conn, result.error = dialFn(ctx, network, ips, port, opt)
   221  	}
   222  
   223  	if len(ipv4s) != 0 {
   224  		wg.Add(1)
   225  		go racer(ipv4s, preferIPVersion != 6)
   226  	}
   227  
   228  	if len(ipv6s) != 0 {
   229  		wg.Add(1)
   230  		go racer(ipv6s, preferIPVersion != 4)
   231  	}
   232  
   233  	go func() {
   234  		wg.Wait()
   235  		close(results)
   236  	}()
   237  
   238  	var fallback dialResult
   239  	var errs []error
   240  
   241  loop:
   242  	for {
   243  		select {
   244  		case <-fallbackTicker.C:
   245  			if fallback.error == nil && fallback.Conn != nil {
   246  				return fallback.Conn, nil
   247  			}
   248  		case res, ok := <-results:
   249  			if !ok {
   250  				break loop
   251  			}
   252  			if res.error == nil {
   253  				if res.isPrimary {
   254  					return res.Conn, nil
   255  				}
   256  				fallback = res
   257  			} else {
   258  				if res.isPrimary {
   259  					errs = append([]error{fmt.Errorf("connect failed: %w", res.error)}, errs...)
   260  				} else {
   261  					errs = append(errs, fmt.Errorf("connect failed: %w", res.error))
   262  				}
   263  			}
   264  		}
   265  	}
   266  
   267  	if fallback.error == nil && fallback.Conn != nil {
   268  		return fallback.Conn, nil
   269  	}
   270  	return nil, errors.Join(errs...)
   271  }
   272  
   273  func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
   274  	if len(ips) == 0 {
   275  		return nil, ErrorNoIpAddress
   276  	}
   277  	results := make(chan dialResult)
   278  	returned := make(chan struct{})
   279  	defer close(returned)
   280  	racer := func(ctx context.Context, ip netip.Addr) {
   281  		result := dialResult{isPrimary: true, ip: ip}
   282  		defer func() {
   283  			select {
   284  			case results <- result:
   285  			case <-returned:
   286  				if result.Conn != nil && result.error == nil {
   287  					_ = result.Conn.Close()
   288  				}
   289  			}
   290  		}()
   291  		result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
   292  	}
   293  
   294  	for _, ip := range ips {
   295  		go racer(ctx, ip)
   296  	}
   297  	var errs []error
   298  	for i := 0; i < len(ips); i++ {
   299  		res := <-results
   300  		if res.error == nil {
   301  			return res.Conn, nil
   302  		}
   303  		errs = append(errs, res.error)
   304  	}
   305  
   306  	if len(errs) > 0 {
   307  		return nil, errors.Join(errs...)
   308  	}
   309  	return nil, os.ErrDeadlineExceeded
   310  }
   311  
   312  func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
   313  	if len(ips) == 0 {
   314  		return nil, ErrorNoIpAddress
   315  	}
   316  	var errs []error
   317  	for _, ip := range ips {
   318  		if conn, err := dialContext(ctx, network, ip, port, opt); err == nil {
   319  			return conn, nil
   320  		} else {
   321  			errs = append(errs, err)
   322  		}
   323  	}
   324  	return nil, errors.Join(errs...)
   325  }
   326  
   327  type dialResult struct {
   328  	ip netip.Addr
   329  	net.Conn
   330  	error
   331  	isPrimary bool
   332  }
   333  
   334  func parseAddr(ctx context.Context, network, address string, preferResolver resolver.Resolver) ([]netip.Addr, string, error) {
   335  	host, port, err := net.SplitHostPort(address)
   336  	if err != nil {
   337  		return nil, "-1", err
   338  	}
   339  
   340  	var ips []netip.Addr
   341  	switch network {
   342  	case "tcp4", "udp4":
   343  		if preferResolver == nil {
   344  			ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host)
   345  		} else {
   346  			ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver)
   347  		}
   348  	case "tcp6", "udp6":
   349  		if preferResolver == nil {
   350  			ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host)
   351  		} else {
   352  			ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver)
   353  		}
   354  	default:
   355  		if preferResolver == nil {
   356  			ips, err = resolver.LookupIPProxyServerHost(ctx, host)
   357  		} else {
   358  			ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver)
   359  		}
   360  	}
   361  	if err != nil {
   362  		return nil, "-1", fmt.Errorf("dns resolve failed: %w", err)
   363  	}
   364  	for i, ip := range ips {
   365  		if ip.Is4In6() {
   366  			ips[i] = ip.Unmap()
   367  		}
   368  	}
   369  	return ips, port, nil
   370  }
   371  
   372  type Dialer struct {
   373  	Opt option
   374  }
   375  
   376  func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   377  	return DialContext(ctx, network, address, WithOption(d.Opt))
   378  }
   379  
   380  func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
   381  	opt := WithOption(d.Opt)
   382  	if rAddrPort.Addr().Unmap().IsLoopback() {
   383  		// avoid "The requested address is not valid in its context."
   384  		opt = WithInterface("")
   385  	}
   386  	return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, rAddrPort, opt)
   387  }
   388  
   389  func NewDialer(options ...Option) Dialer {
   390  	opt := applyOptions(options...)
   391  	return Dialer{Opt: *opt}
   392  }
   393  
   394  func GetIP4PEnable(enableIP4PConvert bool) {
   395  	IP4PEnable = enableIP4PConvert
   396  }
   397  
   398  // kanged from https://github.com/heiher/frp/blob/ip4p/client/ip4p.go
   399  
   400  func lookupIP4P(addr netip.Addr, port string) (netip.Addr, string) {
   401  	ip := addr.AsSlice()
   402  	if ip[0] == 0x20 && ip[1] == 0x01 &&
   403  		ip[2] == 0x00 && ip[3] == 0x00 {
   404  		addr = netip.AddrFrom4([4]byte{ip[12], ip[13], ip[14], ip[15]})
   405  		port = strconv.Itoa(int(ip[10])<<8 + int(ip[11]))
   406  		log.Debugln("Convert IP4P address %s to %s", ip, net.JoinHostPort(addr.String(), port))
   407  		return addr, port
   408  	}
   409  	return addr, port
   410  }