github.com/metacubex/mihomo@v1.18.5/component/dialer/dialer.go (about) 1 package dialer 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "net/netip" 9 "os" 10 "strconv" 11 "strings" 12 "sync" 13 "time" 14 15 "github.com/metacubex/mihomo/component/resolver" 16 "github.com/metacubex/mihomo/constant/features" 17 "github.com/metacubex/mihomo/log" 18 ) 19 20 const ( 21 DefaultTCPTimeout = 5 * time.Second 22 DefaultUDPTimeout = DefaultTCPTimeout 23 ) 24 25 type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) 26 27 var ( 28 dialMux sync.Mutex 29 IP4PEnable bool 30 actualSingleStackDialContext = serialSingleStackDialContext 31 actualDualStackDialContext = serialDualStackDialContext 32 tcpConcurrent = false 33 fallbackTimeout = 300 * time.Millisecond 34 ) 35 36 func applyOptions(options ...Option) *option { 37 opt := &option{ 38 interfaceName: DefaultInterface.Load(), 39 routingMark: int(DefaultRoutingMark.Load()), 40 } 41 42 for _, o := range DefaultOptions { 43 o(opt) 44 } 45 46 for _, o := range options { 47 o(opt) 48 } 49 50 return opt 51 } 52 53 func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { 54 opt := applyOptions(options...) 55 56 if opt.network == 4 || opt.network == 6 { 57 if strings.Contains(network, "tcp") { 58 network = "tcp" 59 } else { 60 network = "udp" 61 } 62 63 network = fmt.Sprintf("%s%d", network, opt.network) 64 } 65 66 ips, port, err := parseAddr(ctx, network, address, opt.resolver) 67 if err != nil { 68 return nil, err 69 } 70 71 switch network { 72 case "tcp4", "tcp6", "udp4", "udp6": 73 return actualSingleStackDialContext(ctx, network, ips, port, opt) 74 case "tcp", "udp": 75 return actualDualStackDialContext(ctx, network, ips, port, opt) 76 default: 77 return nil, ErrorInvalidedNetworkStack 78 } 79 } 80 81 func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) { 82 if features.CMFA && DefaultSocketHook != nil { 83 return listenPacketHooked(ctx, network, address) 84 } 85 86 cfg := applyOptions(options...) 87 88 lc := &net.ListenConfig{} 89 if cfg.interfaceName != "" { 90 bind := bindIfaceToListenConfig 91 if cfg.fallbackBind { 92 bind = fallbackBindIfaceToListenConfig 93 } 94 addr, err := bind(cfg.interfaceName, lc, network, address, rAddrPort) 95 if err != nil { 96 return nil, err 97 } 98 address = addr 99 } 100 if cfg.addrReuse { 101 addrReuseToListenConfig(lc) 102 } 103 if cfg.routingMark != 0 { 104 bindMarkToListenConfig(cfg.routingMark, lc, network, address) 105 } 106 107 return lc.ListenPacket(ctx, network, address) 108 } 109 110 func SetTcpConcurrent(concurrent bool) { 111 dialMux.Lock() 112 defer dialMux.Unlock() 113 tcpConcurrent = concurrent 114 if concurrent { 115 actualSingleStackDialContext = concurrentSingleStackDialContext 116 actualDualStackDialContext = concurrentDualStackDialContext 117 } else { 118 actualSingleStackDialContext = serialSingleStackDialContext 119 actualDualStackDialContext = serialDualStackDialContext 120 } 121 } 122 123 func GetTcpConcurrent() bool { 124 dialMux.Lock() 125 defer dialMux.Unlock() 126 return tcpConcurrent 127 } 128 129 func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) { 130 if features.CMFA && DefaultSocketHook != nil { 131 return dialContextHooked(ctx, network, destination, port) 132 } 133 134 var address string 135 if IP4PEnable { 136 destination, port = lookupIP4P(destination, port) 137 } 138 address = net.JoinHostPort(destination.String(), port) 139 140 netDialer := opt.netDialer 141 switch netDialer.(type) { 142 case nil: 143 netDialer = &net.Dialer{} 144 case *net.Dialer: 145 _netDialer := *netDialer.(*net.Dialer) 146 netDialer = &_netDialer // make a copy 147 default: 148 return netDialer.DialContext(ctx, network, address) 149 } 150 151 dialer := netDialer.(*net.Dialer) 152 if opt.interfaceName != "" { 153 bind := bindIfaceToDialer 154 if opt.fallbackBind { 155 bind = fallbackBindIfaceToDialer 156 } 157 if err := bind(opt.interfaceName, dialer, network, destination); err != nil { 158 return nil, err 159 } 160 } 161 if opt.routingMark != 0 { 162 bindMarkToDialer(opt.routingMark, dialer, network, destination) 163 } 164 if opt.mpTcp { 165 setMultiPathTCP(dialer) 166 } 167 if opt.tfo && !DisableTFO { 168 return dialTFO(ctx, *dialer, network, address) 169 } 170 return dialer.DialContext(ctx, network, address) 171 } 172 173 func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { 174 return serialDialContext(ctx, network, ips, port, opt) 175 } 176 177 func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { 178 return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt) 179 } 180 181 func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { 182 return parallelDialContext(ctx, network, ips, port, opt) 183 } 184 185 func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { 186 if opt.prefer != 4 && opt.prefer != 6 { 187 return parallelDialContext(ctx, network, ips, port, opt) 188 } 189 return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt) 190 } 191 192 func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { 193 ipv4s, ipv6s := resolver.SortationAddr(ips) 194 if len(ipv4s) == 0 && len(ipv6s) == 0 { 195 return nil, ErrorNoIpAddress 196 } 197 198 preferIPVersion := opt.prefer 199 fallbackTicker := time.NewTicker(fallbackTimeout) 200 defer fallbackTicker.Stop() 201 202 results := make(chan dialResult) 203 returned := make(chan struct{}) 204 defer close(returned) 205 206 var wg sync.WaitGroup 207 208 racer := func(ips []netip.Addr, isPrimary bool) { 209 defer wg.Done() 210 result := dialResult{isPrimary: isPrimary} 211 defer func() { 212 select { 213 case results <- result: 214 case <-returned: 215 if result.Conn != nil && result.error == nil { 216 _ = result.Conn.Close() 217 } 218 } 219 }() 220 result.Conn, result.error = dialFn(ctx, network, ips, port, opt) 221 } 222 223 if len(ipv4s) != 0 { 224 wg.Add(1) 225 go racer(ipv4s, preferIPVersion != 6) 226 } 227 228 if len(ipv6s) != 0 { 229 wg.Add(1) 230 go racer(ipv6s, preferIPVersion != 4) 231 } 232 233 go func() { 234 wg.Wait() 235 close(results) 236 }() 237 238 var fallback dialResult 239 var errs []error 240 241 loop: 242 for { 243 select { 244 case <-fallbackTicker.C: 245 if fallback.error == nil && fallback.Conn != nil { 246 return fallback.Conn, nil 247 } 248 case res, ok := <-results: 249 if !ok { 250 break loop 251 } 252 if res.error == nil { 253 if res.isPrimary { 254 return res.Conn, nil 255 } 256 fallback = res 257 } else { 258 if res.isPrimary { 259 errs = append([]error{fmt.Errorf("connect failed: %w", res.error)}, errs...) 260 } else { 261 errs = append(errs, fmt.Errorf("connect failed: %w", res.error)) 262 } 263 } 264 } 265 } 266 267 if fallback.error == nil && fallback.Conn != nil { 268 return fallback.Conn, nil 269 } 270 return nil, errors.Join(errs...) 271 } 272 273 func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { 274 if len(ips) == 0 { 275 return nil, ErrorNoIpAddress 276 } 277 results := make(chan dialResult) 278 returned := make(chan struct{}) 279 defer close(returned) 280 racer := func(ctx context.Context, ip netip.Addr) { 281 result := dialResult{isPrimary: true, ip: ip} 282 defer func() { 283 select { 284 case results <- result: 285 case <-returned: 286 if result.Conn != nil && result.error == nil { 287 _ = result.Conn.Close() 288 } 289 } 290 }() 291 result.Conn, result.error = dialContext(ctx, network, ip, port, opt) 292 } 293 294 for _, ip := range ips { 295 go racer(ctx, ip) 296 } 297 var errs []error 298 for i := 0; i < len(ips); i++ { 299 res := <-results 300 if res.error == nil { 301 return res.Conn, nil 302 } 303 errs = append(errs, res.error) 304 } 305 306 if len(errs) > 0 { 307 return nil, errors.Join(errs...) 308 } 309 return nil, os.ErrDeadlineExceeded 310 } 311 312 func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) { 313 if len(ips) == 0 { 314 return nil, ErrorNoIpAddress 315 } 316 var errs []error 317 for _, ip := range ips { 318 if conn, err := dialContext(ctx, network, ip, port, opt); err == nil { 319 return conn, nil 320 } else { 321 errs = append(errs, err) 322 } 323 } 324 return nil, errors.Join(errs...) 325 } 326 327 type dialResult struct { 328 ip netip.Addr 329 net.Conn 330 error 331 isPrimary bool 332 } 333 334 func parseAddr(ctx context.Context, network, address string, preferResolver resolver.Resolver) ([]netip.Addr, string, error) { 335 host, port, err := net.SplitHostPort(address) 336 if err != nil { 337 return nil, "-1", err 338 } 339 340 var ips []netip.Addr 341 switch network { 342 case "tcp4", "udp4": 343 if preferResolver == nil { 344 ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host) 345 } else { 346 ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver) 347 } 348 case "tcp6", "udp6": 349 if preferResolver == nil { 350 ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host) 351 } else { 352 ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver) 353 } 354 default: 355 if preferResolver == nil { 356 ips, err = resolver.LookupIPProxyServerHost(ctx, host) 357 } else { 358 ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver) 359 } 360 } 361 if err != nil { 362 return nil, "-1", fmt.Errorf("dns resolve failed: %w", err) 363 } 364 for i, ip := range ips { 365 if ip.Is4In6() { 366 ips[i] = ip.Unmap() 367 } 368 } 369 return ips, port, nil 370 } 371 372 type Dialer struct { 373 Opt option 374 } 375 376 func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 377 return DialContext(ctx, network, address, WithOption(d.Opt)) 378 } 379 380 func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { 381 opt := WithOption(d.Opt) 382 if rAddrPort.Addr().Unmap().IsLoopback() { 383 // avoid "The requested address is not valid in its context." 384 opt = WithInterface("") 385 } 386 return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, rAddrPort, opt) 387 } 388 389 func NewDialer(options ...Option) Dialer { 390 opt := applyOptions(options...) 391 return Dialer{Opt: *opt} 392 } 393 394 func GetIP4PEnable(enableIP4PConvert bool) { 395 IP4PEnable = enableIP4PConvert 396 } 397 398 // kanged from https://github.com/heiher/frp/blob/ip4p/client/ip4p.go 399 400 func lookupIP4P(addr netip.Addr, port string) (netip.Addr, string) { 401 ip := addr.AsSlice() 402 if ip[0] == 0x20 && ip[1] == 0x01 && 403 ip[2] == 0x00 && ip[3] == 0x00 { 404 addr = netip.AddrFrom4([4]byte{ip[12], ip[13], ip[14], ip[15]}) 405 port = strconv.Itoa(int(ip[10])<<8 + int(ip[11])) 406 log.Debugln("Convert IP4P address %s to %s", ip, net.JoinHostPort(addr.String(), port)) 407 return addr, port 408 } 409 return addr, port 410 }