github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/dialer/happyeyeballs.go (about) 1 package dialer 2 3 import ( 4 "context" 5 "errors" 6 "net" 7 "time" 8 ) 9 10 // DialHappyEyeballs is a function that implements Happy Eyeballs algorithm for IPv4 and IPv6 addresses. 11 // It divides given TCP addresses into primaries and fallbacks and then calls DialParallel function. 12 // 13 // It takes a context and a slice of TCP addresses as input and returns a net.Conn and an error. 14 // 15 // https://www.rfc-editor.org/rfc/rfc8305 16 func DialHappyEyeballs(ctx context.Context, ips []*net.TCPAddr) (net.Conn, error) { 17 // Divide TCP addresses into primaries and fallbacks based on their IP version. 18 primaries := []*net.TCPAddr{} // TCP addresses with IPv4 version 19 fallback := []*net.TCPAddr{} // TCP addresses with IPv6 version 20 21 for _, ip := range ips { 22 if ip.IP.To4() != nil { 23 fallback = append(fallback, ip) 24 } else { 25 primaries = append(primaries, ip) 26 } 27 } 28 29 // If there are no primaries, use fallbacks as primaries. 30 if len(primaries) == 0 { 31 if len(fallback) == 0 { 32 return nil, errors.New("no addresses") 33 } 34 35 primaries = fallback 36 } 37 38 // Call DialParallel function with primaries and fallbacks. 39 return DialParallel(ctx, primaries, fallback) 40 } 41 42 // https://github.com/golang/go/blob/315b6ae682a2a4e7718924a45b8b311a0fe10043/src/net/dial.go#L534 43 // 44 // dialParallel races two copies of dialSerial, giving the first a 45 // head start. It returns the first established connection and 46 // closes the others. Otherwise it returns an error from the first 47 // primary address. 48 func DialParallel(ctx context.Context, primaries []*net.TCPAddr, fallbacks []*net.TCPAddr) (net.Conn, error) { 49 if len(fallbacks) == 0 { 50 return DialSerial(ctx, primaries) 51 } 52 53 returned := make(chan struct{}) 54 defer close(returned) 55 56 type dialResult struct { 57 net.Conn 58 error 59 primary bool 60 done bool 61 } 62 results := make(chan dialResult) // unbuffered 63 64 startRacer := func(ctx context.Context, primary bool) { 65 ras := primaries 66 if !primary { 67 ras = fallbacks 68 } 69 c, err := DialSerial(ctx, ras) 70 select { 71 case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: 72 case <-returned: 73 if c != nil { 74 c.Close() 75 } 76 } 77 } 78 79 var primary, fallback dialResult 80 81 // Start the main racer. 82 primaryCtx, primaryCancel := context.WithCancel(ctx) 83 defer primaryCancel() 84 go startRacer(primaryCtx, true) 85 86 // Start the timer for the fallback racer. 87 fallbackTimer := time.NewTimer(time.Millisecond * 300) 88 defer fallbackTimer.Stop() 89 90 for { 91 select { 92 case <-fallbackTimer.C: 93 fallbackCtx, fallbackCancel := context.WithCancel(ctx) 94 defer fallbackCancel() 95 go startRacer(fallbackCtx, false) 96 97 case res := <-results: 98 if res.error == nil { 99 return res.Conn, nil 100 } 101 if res.primary { 102 primary = res 103 } else { 104 fallback = res 105 } 106 if primary.done && fallback.done { 107 return nil, primary.error 108 } 109 if res.primary && fallbackTimer.Stop() { 110 // If we were able to stop the timer, that means it 111 // was running (hadn't yet started the fallback), but 112 // we just got an error on the primary path, so start 113 // the fallback immediately (in 0 nanoseconds). 114 fallbackTimer.Reset(0) 115 } 116 } 117 } 118 } 119 120 // DialSerial connects to a list of addresses in sequence, returning 121 // either the first successful connection, or the first error. 122 func DialSerial(ctx context.Context, ras []*net.TCPAddr) (net.Conn, error) { 123 var firstErr error // The error from the first address is most relevant. 124 125 for i, ra := range ras { 126 select { 127 case <-ctx.Done(): 128 return nil, ctx.Err() 129 default: 130 } 131 132 dialCtx, cancel, err := PartialDeadlineCtx(ctx, len(ras)-i) 133 if err != nil { 134 // Ran out of time. 135 if firstErr == nil { 136 firstErr = err 137 } 138 break 139 } 140 defer cancel() 141 142 c, err := dialSingle(dialCtx, ra) 143 if err == nil { 144 return c, nil 145 } 146 if firstErr == nil { 147 firstErr = err 148 } 149 } 150 151 if firstErr == nil { 152 firstErr = errors.New("errMissingAddress") 153 } 154 return nil, firstErr 155 } 156 157 func PartialDeadlineCtx(ctx context.Context, addrsRemaining int) (context.Context, context.CancelFunc, error) { 158 dialCtx := ctx 159 cancel := func() {} 160 if deadline, hasDeadline := ctx.Deadline(); hasDeadline { 161 partialDeadline, err := PartialDeadline(time.Now(), deadline, addrsRemaining) 162 if err != nil { 163 // Ran out of time. 164 return dialCtx, cancel, err 165 } 166 if partialDeadline.Before(deadline) { 167 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) 168 } 169 } 170 171 return dialCtx, cancel, nil 172 } 173 174 func dialSingle(ctx context.Context, ips *net.TCPAddr) (net.Conn, error) { 175 return DialContext(ctx, "tcp", ips.String()) 176 } 177 178 // PartialDeadline returns the deadline to use for a single address, 179 // when multiple addresses are pending. 180 func PartialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { 181 if deadline.IsZero() { 182 return deadline, nil 183 } 184 timeRemaining := deadline.Sub(now) 185 if timeRemaining <= 0 { 186 return time.Time{}, errors.New("errTimeout") 187 } 188 // Tentatively allocate equal time to each remaining address. 189 timeout := timeRemaining / time.Duration(addrsRemaining) 190 // If the time per address is too short, steal from the end of the list. 191 const saneMinimum = 2 * time.Second 192 if timeout < saneMinimum { 193 if timeRemaining < saneMinimum { 194 timeout = timeRemaining 195 } else { 196 timeout = saneMinimum 197 } 198 } 199 return now.Add(timeout), nil 200 }