github.com/metacubex/tfo-go@v0.0.0-20240228025757-be1269474a66/tfo_supported.go (about) 1 //go:build darwin || freebsd || linux || windows 2 3 package tfo 4 5 import ( 6 "context" 7 "net" 8 "os" 9 "syscall" 10 "time" 11 _ "unsafe" 12 ) 13 14 const comptimeNoTFO = false 15 16 const ( 17 defaultTCPKeepAlive = 15 * time.Second 18 defaultFallbackDelay = 300 * time.Millisecond 19 ) 20 21 // Boolean to int. 22 func boolint(b bool) int { 23 if b { 24 return 1 25 } 26 return 0 27 } 28 29 // A sockaddr represents a TCP, UDP, IP or Unix network endpoint 30 // address that can be converted into a syscall.Sockaddr. 31 // 32 // Copied from src/net/sockaddr_posix.go 33 type sockaddr interface { 34 net.Addr 35 36 // family returns the platform-dependent address family 37 // identifier. 38 family() int 39 40 // isWildcard reports whether the address is a wildcard 41 // address. 42 isWildcard() bool 43 44 // sockaddr returns the address converted into a syscall 45 // sockaddr type that implements syscall.Sockaddr 46 // interface. It returns a nil interface when the address is 47 // nil. 48 sockaddr(family int) (syscall.Sockaddr, error) 49 50 // toLocal maps the zero address to a local system address (127.0.0.1 or ::1) 51 toLocal(net string) sockaddr 52 } 53 54 type tcpSockaddr net.TCPAddr 55 56 func (a *tcpSockaddr) Network() string { 57 return "tcp" 58 } 59 60 func (a *tcpSockaddr) String() string { 61 return (*net.TCPAddr)(a).String() 62 } 63 64 // Copied from src/net/tcpsock_posix.go 65 func (a *tcpSockaddr) family() int { 66 if a == nil || len(a.IP) <= net.IPv4len { 67 return syscall.AF_INET 68 } 69 if a.IP.To4() != nil { 70 return syscall.AF_INET 71 } 72 return syscall.AF_INET6 73 } 74 75 // Copied from src/net/tcpsock_posix.go 76 func (a *tcpSockaddr) isWildcard() bool { 77 if a == nil || a.IP == nil { 78 return true 79 } 80 return a.IP.IsUnspecified() 81 } 82 83 //go:linkname ipToSockaddr net.ipToSockaddr 84 func ipToSockaddr(family int, ip net.IP, port int, zone string) (syscall.Sockaddr, error) 85 86 // Copied from src/net/tcpsock_posix.go 87 func (a *tcpSockaddr) sockaddr(family int) (syscall.Sockaddr, error) { 88 if a == nil { 89 return nil, nil 90 } 91 return ipToSockaddr(family, a.IP, a.Port, a.Zone) 92 } 93 94 //go:linkname loopbackIP net.loopbackIP 95 func loopbackIP(net string) net.IP 96 97 // Modified from src/net/tcpsock_posix.go 98 func (a *tcpSockaddr) toLocal(net string) sockaddr { 99 la := *a 100 la.IP = loopbackIP(net) 101 return &la 102 } 103 104 //go:linkname favoriteAddrFamily net.favoriteAddrFamily 105 func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (family int, ipv6only bool) 106 107 func (d *Dialer) dialTFOFromSocket(ctx context.Context, network, address string, b []byte) (*net.TCPConn, error) { 108 if ctx == nil { 109 panic("nil context") 110 } 111 deadline := d.deadline(ctx, time.Now()) 112 if !deadline.IsZero() { 113 if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { 114 subCtx, cancel := context.WithDeadline(ctx, deadline) 115 defer cancel() 116 ctx = subCtx 117 } 118 } 119 if oldCancel := d.Cancel; oldCancel != nil { 120 subCtx, cancel := context.WithCancel(ctx) 121 defer cancel() 122 go func() { 123 select { 124 case <-oldCancel: 125 cancel() 126 case <-subCtx.Done(): 127 } 128 }() 129 ctx = subCtx 130 } 131 132 var laddr *net.TCPAddr 133 if d.LocalAddr != nil { 134 la, ok := d.LocalAddr.(*net.TCPAddr) 135 if !ok { 136 return nil, &net.OpError{ 137 Op: "dial", 138 Net: network, 139 Source: nil, 140 Addr: nil, 141 Err: &net.AddrError{ 142 Err: "mismatched local address type", 143 Addr: d.LocalAddr.String(), 144 }, 145 } 146 } 147 laddr = la 148 } 149 150 host, port, err := net.SplitHostPort(address) 151 if err != nil { 152 return nil, &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} 153 } 154 portNum, err := d.Resolver.LookupPort(ctx, network, port) 155 if err != nil { 156 return nil, &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} 157 } 158 ipaddrs, err := d.Resolver.LookupIPAddr(ctx, host) 159 if err != nil { 160 return nil, &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} 161 } 162 163 var addrs []*net.TCPAddr 164 165 for _, ipaddr := range ipaddrs { 166 if laddr != nil && !laddr.IP.IsUnspecified() && !matchAddrFamily(laddr.IP, ipaddr.IP) { 167 continue 168 } 169 addrs = append(addrs, &net.TCPAddr{ 170 IP: ipaddr.IP, 171 Port: portNum, 172 Zone: ipaddr.Zone, 173 }) 174 } 175 176 var primaries, fallbacks []*net.TCPAddr 177 if d.FallbackDelay >= 0 && network == "tcp" { 178 primaries, fallbacks = partition(addrs, func(a *net.TCPAddr) bool { 179 return a.IP.To4() != nil 180 }) 181 } else { 182 primaries = addrs 183 } 184 185 var c *net.TCPConn 186 if len(fallbacks) > 0 { 187 c, err = d.dialParallel(ctx, network, laddr, primaries, fallbacks, b) 188 } else { 189 c, err = d.dialSerial(ctx, network, laddr, primaries, b) 190 } 191 if err != nil { 192 return nil, err 193 } 194 195 if d.KeepAlive >= 0 { 196 c.SetKeepAlive(true) 197 ka := d.KeepAlive 198 if d.KeepAlive == 0 { 199 ka = defaultTCPKeepAlive 200 } 201 c.SetKeepAlivePeriod(ka) 202 } 203 return c, nil 204 } 205 206 // dialParallel races two copies of dialSerial, giving the first a 207 // head start. It returns the first established connection and 208 // closes the others. Otherwise it returns an error from the first 209 // primary address. 210 func (d *Dialer) dialParallel(ctx context.Context, network string, laddr *net.TCPAddr, primaries, fallbacks []*net.TCPAddr, b []byte) (*net.TCPConn, error) { 211 if len(fallbacks) == 0 { 212 return d.dialSerial(ctx, network, laddr, primaries, b) 213 } 214 215 returned := make(chan struct{}) 216 defer close(returned) 217 218 type dialResult struct { 219 *net.TCPConn 220 error 221 primary bool 222 done bool 223 } 224 results := make(chan dialResult) // unbuffered 225 226 startRacer := func(ctx context.Context, primary bool) { 227 ras := primaries 228 if !primary { 229 ras = fallbacks 230 } 231 c, err := d.dialSerial(ctx, network, laddr, ras, b) 232 select { 233 case results <- dialResult{TCPConn: c, error: err, primary: primary, done: true}: 234 case <-returned: 235 if c != nil { 236 c.Close() 237 } 238 } 239 } 240 241 var primary, fallback dialResult 242 243 // Start the main racer. 244 primaryCtx, primaryCancel := context.WithCancel(ctx) 245 defer primaryCancel() 246 go startRacer(primaryCtx, true) 247 248 // Start the timer for the fallback racer. 249 fallbackDelay := d.FallbackDelay 250 if fallbackDelay == 0 { 251 fallbackDelay = defaultFallbackDelay 252 } 253 fallbackTimer := time.NewTimer(fallbackDelay) 254 defer fallbackTimer.Stop() 255 256 for { 257 select { 258 case <-fallbackTimer.C: 259 fallbackCtx, fallbackCancel := context.WithCancel(ctx) 260 defer fallbackCancel() 261 go startRacer(fallbackCtx, false) 262 263 case res := <-results: 264 if res.error == nil { 265 return res.TCPConn, nil 266 } 267 if res.primary { 268 primary = res 269 } else { 270 fallback = res 271 } 272 if primary.done && fallback.done { 273 return nil, primary.error 274 } 275 if res.primary && fallbackTimer.Stop() { 276 // If we were able to stop the timer, that means it 277 // was running (hadn't yet started the fallback), but 278 // we just got an error on the primary path, so start 279 // the fallback immediately (in 0 nanoseconds). 280 fallbackTimer.Reset(0) 281 } 282 } 283 } 284 } 285 286 // dialSerial connects to a list of addresses in sequence, returning 287 // either the first successful connection, or the first error. 288 func (d *Dialer) dialSerial(ctx context.Context, network string, laddr *net.TCPAddr, ras []*net.TCPAddr, b []byte) (*net.TCPConn, error) { 289 var firstErr error // The error from the first address is most relevant. 290 291 for i, ra := range ras { 292 select { 293 case <-ctx.Done(): 294 return nil, &net.OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: ra, Err: ctx.Err()} 295 default: 296 } 297 298 dialCtx := ctx 299 if deadline, hasDeadline := ctx.Deadline(); hasDeadline { 300 partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i) 301 if err != nil { 302 // Ran out of time. 303 if firstErr == nil { 304 firstErr = &net.OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: ra, Err: err} 305 } 306 break 307 } 308 if partialDeadline.Before(deadline) { 309 var cancel context.CancelFunc 310 dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) 311 defer cancel() 312 } 313 } 314 315 ctrlCtxFn := d.ControlContext 316 if ctrlCtxFn == nil && d.Control != nil { 317 ctrlCtxFn = func(ctx context.Context, network, address string, c syscall.RawConn) error { 318 return d.Control(network, address, c) 319 } 320 } 321 322 c, err := d.dialSingle(dialCtx, network, laddr, ra, b, ctrlCtxFn) 323 if err == nil { 324 return c, nil 325 } 326 if firstErr == nil { 327 firstErr = &net.OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: ra, Err: err} 328 } 329 } 330 331 if firstErr == nil { 332 firstErr = &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: errMissingAddress} 333 } 334 return nil, firstErr 335 } 336 337 func matchAddrFamily(x, y net.IP) bool { 338 return x.To4() != nil && y.To4() != nil || x.To16() != nil && x.To4() == nil && y.To16() != nil && y.To4() == nil 339 } 340 341 // partition divides an address list into two categories, using a 342 // strategy function to assign a boolean label to each address. 343 // The first address, and any with a matching label, are returned as 344 // primaries, while addresses with the opposite label are returned 345 // as fallbacks. For non-empty inputs, primaries is guaranteed to be 346 // non-empty. 347 func partition(addrs []*net.TCPAddr, strategy func(*net.TCPAddr) bool) (primaries, fallbacks []*net.TCPAddr) { 348 var primaryLabel bool 349 for i, addr := range addrs { 350 label := strategy(addr) 351 if i == 0 || label == primaryLabel { 352 primaryLabel = label 353 primaries = append(primaries, addr) 354 } else { 355 fallbacks = append(fallbacks, addr) 356 } 357 } 358 return 359 } 360 361 func minNonzeroTime(a, b time.Time) time.Time { 362 if a.IsZero() { 363 return b 364 } 365 if b.IsZero() || a.Before(b) { 366 return a 367 } 368 return b 369 } 370 371 // deadline returns the earliest of: 372 // - now+Timeout 373 // - d.Deadline 374 // - the context's deadline 375 // 376 // Or zero, if none of Timeout, Deadline, or context's deadline is set. 377 func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) { 378 if d.Timeout != 0 { // including negative, for historical reasons 379 earliest = now.Add(d.Timeout) 380 } 381 if d, ok := ctx.Deadline(); ok { 382 earliest = minNonzeroTime(earliest, d) 383 } 384 return minNonzeroTime(earliest, d.Deadline) 385 } 386 387 // partialDeadline returns the deadline to use for a single address, 388 // when multiple addresses are pending. 389 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { 390 if deadline.IsZero() { 391 return deadline, nil 392 } 393 timeRemaining := deadline.Sub(now) 394 if timeRemaining <= 0 { 395 return time.Time{}, os.ErrDeadlineExceeded 396 } 397 // Tentatively allocate equal time to each remaining address. 398 timeout := timeRemaining / time.Duration(addrsRemaining) 399 // If the time per address is too short, steal from the end of the list. 400 const saneMinimum = 2 * time.Second 401 if timeout < saneMinimum { 402 if timeRemaining < saneMinimum { 403 timeout = timeRemaining 404 } else { 405 timeout = saneMinimum 406 } 407 } 408 return now.Add(timeout), nil 409 }