github.com/sagernet/tfo-go@v0.0.0-20231209031829-7b5343ac1dc6/tfo_supported.go (about)

     1  //go:build darwin || freebsd || linux || windows
     2  
     3  package tfo
     4  
     5  import (
     6  	"context"
     7  	"net"
     8  	"os"
     9  	"syscall"
    10  	"time"
    11  	_ "unsafe"
    12  )
    13  
// comptimeNoTFO is false in this file, which is only built for the
// platforms named in the build constraint at the top of the file
// (darwin, freebsd, linux, windows).
// NOTE(review): presumably a counterpart file for other platforms sets it to
// true to rule out TFO at compile time — confirm.
const comptimeNoTFO = false

const (
	// defaultTCPKeepAlive is the keep-alive period dialTFOFromSocket applies
	// when Dialer.KeepAlive is zero.
	defaultTCPKeepAlive  = 15 * time.Second
	// defaultFallbackDelay is how long dialParallel waits before starting the
	// fallback racer when Dialer.FallbackDelay is zero.
	defaultFallbackDelay = 300 * time.Millisecond
)
    20  
    21  // Boolean to int.
    22  func boolint(b bool) int {
    23  	if b {
    24  		return 1
    25  	}
    26  	return 0
    27  }
    28  
// A sockaddr represents a TCP, UDP, IP or Unix network endpoint
// address that can be converted into a syscall.Sockaddr.
//
// Copied from src/net/sockaddr_posix.go
//
// NOTE(review): values of this interface are passed to the linknamed
// net.favoriteAddrFamily declared later in this file. That function presumably
// expects net's own unexported sockaddr interface, so this method set
// (names and signatures) likely has to stay identical to net's — verify
// against the Go toolchain before adding, removing or changing a method.
type sockaddr interface {
	net.Addr

	// family returns the platform-dependent address family
	// identifier.
	family() int

	// isWildcard reports whether the address is a wildcard
	// address.
	isWildcard() bool

	// sockaddr returns the address converted into a syscall
	// sockaddr type that implements syscall.Sockaddr
	// interface. It returns a nil interface when the address is
	// nil.
	sockaddr(family int) (syscall.Sockaddr, error)

	// toLocal maps the zero address to a local system address (127.0.0.1 or ::1)
	toLocal(net string) sockaddr
}
    53  
// tcpSockaddr is net.TCPAddr with the additional methods needed to satisfy
// the sockaddr interface. Its methods convert back with (*net.TCPAddr)(a)
// to reach TCPAddr's own behavior.
type tcpSockaddr net.TCPAddr
    55  
    56  func (a *tcpSockaddr) Network() string {
    57  	return "tcp"
    58  }
    59  
    60  func (a *tcpSockaddr) String() string {
    61  	return (*net.TCPAddr)(a).String()
    62  }
    63  
    64  // Copied from src/net/tcpsock_posix.go
    65  func (a *tcpSockaddr) family() int {
    66  	if a == nil || len(a.IP) <= net.IPv4len {
    67  		return syscall.AF_INET
    68  	}
    69  	if a.IP.To4() != nil {
    70  		return syscall.AF_INET
    71  	}
    72  	return syscall.AF_INET6
    73  }
    74  
    75  // Copied from src/net/tcpsock_posix.go
    76  func (a *tcpSockaddr) isWildcard() bool {
    77  	if a == nil || a.IP == nil {
    78  		return true
    79  	}
    80  	return a.IP.IsUnspecified()
    81  }
    82  
// ipToSockaddr is bound via go:linkname to net's unexported ipToSockaddr
// (the blank "unsafe" import is what permits the directive). It is used by
// tcpSockaddr.sockaddr to build a syscall.Sockaddr from a family, IP, port
// and zone.
//
// NOTE(review): this depends on an internal standard-library symbol; the
// signature must match the toolchain's net package exactly — re-check when
// upgrading Go.
//go:linkname ipToSockaddr net.ipToSockaddr
func ipToSockaddr(family int, ip net.IP, port int, zone string) (syscall.Sockaddr, error)
    85  
    86  // Copied from src/net/tcpsock_posix.go
    87  func (a *tcpSockaddr) sockaddr(family int) (syscall.Sockaddr, error) {
    88  	if a == nil {
    89  		return nil, nil
    90  	}
    91  	return ipToSockaddr(family, a.IP, a.Port, a.Zone)
    92  }
    93  
// loopbackIP is bound via go:linkname to net's unexported loopbackIP. It is
// used by tcpSockaddr.toLocal to obtain the loopback IP for the given
// network name.
//
// NOTE(review): internal standard-library symbol; keep the signature in sync
// with the toolchain's net package.
//go:linkname loopbackIP net.loopbackIP
func loopbackIP(net string) net.IP
    96  
    97  // Modified from src/net/tcpsock_posix.go
    98  func (a *tcpSockaddr) toLocal(net string) sockaddr {
    99  	la := *a
   100  	la.IP = loopbackIP(net)
   101  	return &la
   102  }
   103  
// favoriteAddrFamily is bound via go:linkname to net's unexported
// favoriteAddrFamily, which picks a socket family (and whether the socket
// is IPv6-only) from the network name and the local/remote addresses.
//
// NOTE(review): it is not called in the code visible in this file;
// presumably dialSingle (defined elsewhere) uses it — confirm. The sockaddr
// arguments are this package's interface, so its method set must stay
// compatible with net's internal one.
//go:linkname favoriteAddrFamily net.favoriteAddrFamily
func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (family int, ipv6only bool)
   106  
   107  func (d *Dialer) dialTFOFromSocket(ctx context.Context, network, address string, b []byte) (*net.TCPConn, error) {
   108  	if ctx == nil {
   109  		panic("nil context")
   110  	}
   111  	deadline := d.deadline(ctx, time.Now())
   112  	if !deadline.IsZero() {
   113  		if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
   114  			subCtx, cancel := context.WithDeadline(ctx, deadline)
   115  			defer cancel()
   116  			ctx = subCtx
   117  		}
   118  	}
   119  	if oldCancel := d.Cancel; oldCancel != nil {
   120  		subCtx, cancel := context.WithCancel(ctx)
   121  		defer cancel()
   122  		go func() {
   123  			select {
   124  			case <-oldCancel:
   125  				cancel()
   126  			case <-subCtx.Done():
   127  			}
   128  		}()
   129  		ctx = subCtx
   130  	}
   131  
   132  	var laddr *net.TCPAddr
   133  	if d.LocalAddr != nil {
   134  		la, ok := d.LocalAddr.(*net.TCPAddr)
   135  		if !ok {
   136  			return nil, &net.OpError{
   137  				Op:     "dial",
   138  				Net:    network,
   139  				Source: nil,
   140  				Addr:   nil,
   141  				Err: &net.AddrError{
   142  					Err:  "mismatched local address type",
   143  					Addr: d.LocalAddr.String(),
   144  				},
   145  			}
   146  		}
   147  		laddr = la
   148  	}
   149  
   150  	host, port, err := net.SplitHostPort(address)
   151  	if err != nil {
   152  		return nil, &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
   153  	}
   154  	portNum, err := d.Resolver.LookupPort(ctx, network, port)
   155  	if err != nil {
   156  		return nil, &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
   157  	}
   158  	ipaddrs, err := d.Resolver.LookupIPAddr(ctx, host)
   159  	if err != nil {
   160  		return nil, &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
   161  	}
   162  
   163  	var addrs []*net.TCPAddr
   164  
   165  	for _, ipaddr := range ipaddrs {
   166  		if laddr != nil && !laddr.IP.IsUnspecified() && !matchAddrFamily(laddr.IP, ipaddr.IP) {
   167  			continue
   168  		}
   169  		addrs = append(addrs, &net.TCPAddr{
   170  			IP:   ipaddr.IP,
   171  			Port: portNum,
   172  			Zone: ipaddr.Zone,
   173  		})
   174  	}
   175  
   176  	var primaries, fallbacks []*net.TCPAddr
   177  	if d.FallbackDelay >= 0 && network == "tcp" {
   178  		primaries, fallbacks = partition(addrs, func(a *net.TCPAddr) bool {
   179  			return a.IP.To4() != nil
   180  		})
   181  	} else {
   182  		primaries = addrs
   183  	}
   184  
   185  	var c *net.TCPConn
   186  	if len(fallbacks) > 0 {
   187  		c, err = d.dialParallel(ctx, network, laddr, primaries, fallbacks, b)
   188  	} else {
   189  		c, err = d.dialSerial(ctx, network, laddr, primaries, b)
   190  	}
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  
   195  	if d.KeepAlive >= 0 {
   196  		c.SetKeepAlive(true)
   197  		ka := d.KeepAlive
   198  		if d.KeepAlive == 0 {
   199  			ka = defaultTCPKeepAlive
   200  		}
   201  		c.SetKeepAlivePeriod(ka)
   202  	}
   203  	return c, nil
   204  }
   205  
// dialParallel races two copies of dialSerial, giving the first a
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
// primary address.
//
// The primary racer dials primaries immediately. The fallback racer dials
// fallbacks once the fallback delay elapses (d.FallbackDelay, or
// defaultFallbackDelay when it is zero), or immediately after the primary
// racer fails, whichever happens first. With no fallbacks this is simply
// dialSerial over primaries. b is forwarded unchanged to dialSerial.
func (d *Dialer) dialParallel(ctx context.Context, network string, laddr *net.TCPAddr, primaries, fallbacks []*net.TCPAddr, b []byte) (*net.TCPConn, error) {
	if len(fallbacks) == 0 {
		return d.dialSerial(ctx, network, laddr, primaries, b)
	}

	// returned is closed when dialParallel exits. A racer that finishes
	// afterwards sees it instead of blocking forever on the unbuffered
	// results channel, and closes its now-unwanted connection.
	returned := make(chan struct{})
	defer close(returned)

	type dialResult struct {
		*net.TCPConn
		error
		primary bool
		done    bool // true once a racer has reported; the zero value means "still running"
	}
	results := make(chan dialResult) // unbuffered

	// startRacer dials primaries (primary == true) or fallbacks serially and
	// reports the outcome on results, unless dialParallel already returned.
	startRacer := func(ctx context.Context, primary bool) {
		ras := primaries
		if !primary {
			ras = fallbacks
		}
		c, err := d.dialSerial(ctx, network, laddr, ras, b)
		select {
		case results <- dialResult{TCPConn: c, error: err, primary: primary, done: true}:
		case <-returned:
			if c != nil {
				c.Close()
			}
		}
	}

	// Last failed result from each racer (successes return immediately).
	var primary, fallback dialResult

	// Start the main racer.
	primaryCtx, primaryCancel := context.WithCancel(ctx)
	defer primaryCancel()
	go startRacer(primaryCtx, true)

	// Start the timer for the fallback racer.
	fallbackDelay := d.FallbackDelay
	if fallbackDelay == 0 {
		fallbackDelay = defaultFallbackDelay
	}
	fallbackTimer := time.NewTimer(fallbackDelay)
	defer fallbackTimer.Stop()

	for {
		select {
		case <-fallbackTimer.C:
			// This case runs at most once: the timer is only re-armed after a
			// successful Stop below. The deferred cancel therefore runs once,
			// when dialParallel returns.
			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
			defer fallbackCancel()
			go startRacer(fallbackCtx, false)

		case res := <-results:
			if res.error == nil {
				return res.TCPConn, nil
			}
			if res.primary {
				primary = res
			} else {
				fallback = res
			}
			// Both racers failed: the primary path's error is reported.
			if primary.done && fallback.done {
				return nil, primary.error
			}
			if res.primary && fallbackTimer.Stop() {
				// If we were able to stop the timer, that means it
				// was running (hadn't yet started the fallback), but
				// we just got an error on the primary path, so start
				// the fallback immediately (in 0 nanoseconds).
				fallbackTimer.Reset(0)
			}
		}
	}
}
   285  
   286  // dialSerial connects to a list of addresses in sequence, returning
   287  // either the first successful connection, or the first error.
   288  func (d *Dialer) dialSerial(ctx context.Context, network string, laddr *net.TCPAddr, ras []*net.TCPAddr, b []byte) (*net.TCPConn, error) {
   289  	var firstErr error // The error from the first address is most relevant.
   290  
   291  	for i, ra := range ras {
   292  		select {
   293  		case <-ctx.Done():
   294  			return nil, &net.OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: ra, Err: ctx.Err()}
   295  		default:
   296  		}
   297  
   298  		dialCtx := ctx
   299  		if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
   300  			partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
   301  			if err != nil {
   302  				// Ran out of time.
   303  				if firstErr == nil {
   304  					firstErr = &net.OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: ra, Err: err}
   305  				}
   306  				break
   307  			}
   308  			if partialDeadline.Before(deadline) {
   309  				var cancel context.CancelFunc
   310  				dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
   311  				defer cancel()
   312  			}
   313  		}
   314  
   315  		ctrlCtxFn := d.ControlContext
   316  		if ctrlCtxFn == nil && d.Control != nil {
   317  			ctrlCtxFn = func(ctx context.Context, network, address string, c syscall.RawConn) error {
   318  				return d.Control(network, address, c)
   319  			}
   320  		}
   321  
   322  		c, err := d.dialSingle(dialCtx, network, laddr, ra, b, ctrlCtxFn)
   323  		if err == nil {
   324  			return c, nil
   325  		}
   326  		if firstErr == nil {
   327  			firstErr = &net.OpError{Op: "dial", Net: network, Source: d.LocalAddr, Addr: ra, Err: err}
   328  		}
   329  	}
   330  
   331  	if firstErr == nil {
   332  		firstErr = &net.OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: errMissingAddress}
   333  	}
   334  	return nil, firstErr
   335  }
   336  
   337  func matchAddrFamily(x, y net.IP) bool {
   338  	return x.To4() != nil && y.To4() != nil || x.To16() != nil && x.To4() == nil && y.To16() != nil && y.To4() == nil
   339  }
   340  
   341  // partition divides an address list into two categories, using a
   342  // strategy function to assign a boolean label to each address.
   343  // The first address, and any with a matching label, are returned as
   344  // primaries, while addresses with the opposite label are returned
   345  // as fallbacks. For non-empty inputs, primaries is guaranteed to be
   346  // non-empty.
   347  func partition(addrs []*net.TCPAddr, strategy func(*net.TCPAddr) bool) (primaries, fallbacks []*net.TCPAddr) {
   348  	var primaryLabel bool
   349  	for i, addr := range addrs {
   350  		label := strategy(addr)
   351  		if i == 0 || label == primaryLabel {
   352  			primaryLabel = label
   353  			primaries = append(primaries, addr)
   354  		} else {
   355  			fallbacks = append(fallbacks, addr)
   356  		}
   357  	}
   358  	return
   359  }
   360  
   361  func minNonzeroTime(a, b time.Time) time.Time {
   362  	if a.IsZero() {
   363  		return b
   364  	}
   365  	if b.IsZero() || a.Before(b) {
   366  		return a
   367  	}
   368  	return b
   369  }
   370  
   371  // deadline returns the earliest of:
   372  //   - now+Timeout
   373  //   - d.Deadline
   374  //   - the context's deadline
   375  //
   376  // Or zero, if none of Timeout, Deadline, or context's deadline is set.
   377  func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
   378  	if d.Timeout != 0 { // including negative, for historical reasons
   379  		earliest = now.Add(d.Timeout)
   380  	}
   381  	if d, ok := ctx.Deadline(); ok {
   382  		earliest = minNonzeroTime(earliest, d)
   383  	}
   384  	return minNonzeroTime(earliest, d.Deadline)
   385  }
   386  
   387  // partialDeadline returns the deadline to use for a single address,
   388  // when multiple addresses are pending.
   389  func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
   390  	if deadline.IsZero() {
   391  		return deadline, nil
   392  	}
   393  	timeRemaining := deadline.Sub(now)
   394  	if timeRemaining <= 0 {
   395  		return time.Time{}, os.ErrDeadlineExceeded
   396  	}
   397  	// Tentatively allocate equal time to each remaining address.
   398  	timeout := timeRemaining / time.Duration(addrsRemaining)
   399  	// If the time per address is too short, steal from the end of the list.
   400  	const saneMinimum = 2 * time.Second
   401  	if timeout < saneMinimum {
   402  		if timeRemaining < saneMinimum {
   403  			timeout = timeRemaining
   404  		} else {
   405  			timeout = saneMinimum
   406  		}
   407  	}
   408  	return now.Add(timeout), nil
   409  }