github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/network/multi.go (about) 1 package network 2 3 import ( 4 "context" 5 "net" 6 "net/netip" 7 "time" 8 9 "github.com/sagernet/sing/common" 10 E "github.com/sagernet/sing/common/exceptions" 11 M "github.com/sagernet/sing/common/metadata" 12 ) 13 14 const DefaultFallbackDelay = 300 * time.Millisecond 15 16 func DialSerial(ctx context.Context, dialer Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) { 17 if parallelDialer, isParallel := dialer.(ParallelDialer); isParallel { 18 return parallelDialer.DialParallel(ctx, network, destination, destinationAddresses) 19 } 20 var conn net.Conn 21 var err error 22 var connErrors []error 23 for _, address := range destinationAddresses { 24 conn, err = dialer.DialContext(ctx, network, M.SocksaddrFrom(address, destination.Port)) 25 if err != nil { 26 connErrors = append(connErrors, err) 27 continue 28 } 29 return conn, nil 30 } 31 return nil, E.Errors(connErrors...) 32 } 33 34 func ListenSerial(ctx context.Context, dialer Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.PacketConn, netip.Addr, error) { 35 var conn net.PacketConn 36 var err error 37 var connErrors []error 38 for _, address := range destinationAddresses { 39 conn, err = dialer.ListenPacket(ctx, M.SocksaddrFrom(address, destination.Port)) 40 if err != nil { 41 connErrors = append(connErrors, err) 42 continue 43 } 44 return conn, address, nil 45 } 46 return nil, netip.Addr{}, E.Errors(connErrors...) 47 } 48 49 func DialParallel(ctx context.Context, dialer Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, fallbackDelay time.Duration) (net.Conn, error) { 50 // kanged form net.Dial 51 52 if fallbackDelay == 0 { 53 fallbackDelay = DefaultFallbackDelay 54 } 55 56 returned := make(chan struct{}) 57 defer close(returned) 58 59 addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool { 60 return address.Is4() || address.Is4In6() 61 }) 62 addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool { 63 return address.Is6() && !address.Is4In6() 64 }) 65 if len(addresses4) == 0 || len(addresses6) == 0 { 66 return DialSerial(ctx, dialer, network, destination, destinationAddresses) 67 } 68 var primaries, fallbacks []netip.Addr 69 if preferIPv6 { 70 primaries = addresses6 71 fallbacks = addresses4 72 } else { 73 primaries = addresses4 74 fallbacks = addresses6 75 } 76 type dialResult struct { 77 net.Conn 78 error 79 primary bool 80 done bool 81 } 82 results := make(chan dialResult) // unbuffered 83 startRacer := func(ctx context.Context, primary bool) { 84 ras := primaries 85 if !primary { 86 ras = fallbacks 87 } 88 c, err := DialSerial(ctx, dialer, network, destination, ras) 89 select { 90 case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: 91 case <-returned: 92 if c != nil { 93 c.Close() 94 } 95 } 96 } 97 var primary, fallback dialResult 98 primaryCtx, primaryCancel := context.WithCancel(ctx) 99 defer primaryCancel() 100 go startRacer(primaryCtx, true) 101 fallbackTimer := time.NewTimer(fallbackDelay) 102 defer fallbackTimer.Stop() 103 for { 104 select { 105 case <-fallbackTimer.C: 106 fallbackCtx, fallbackCancel := context.WithCancel(ctx) 107 defer fallbackCancel() 108 go startRacer(fallbackCtx, false) 109 110 case res := <-results: 111 if res.error == nil { 112 return res.Conn, nil 113 } 114 if res.primary { 115 primary = res 116 } else { 117 fallback = res 118 } 119 if primary.done && fallback.done { 120 return nil, primary.error 121 } 122 if res.primary && fallbackTimer.Stop() { 123 fallbackTimer.Reset(0) 124 } 125 } 126 } 127 }