github.com/metacubex/tfo-go@v0.0.0-20240228025757-be1269474a66/tfo.go (about)

     1  // Package tfo provides TCP Fast Open support for the [net] dialer and listener.
     2  //
     3  // The dial functions have an additional buffer parameter, which specifies data in SYN.
     4  // If the buffer is empty, TFO is not used.
     5  //
     6  // This package supports Linux, Windows, macOS, and FreeBSD.
     7  // On unsupported platforms, [ErrPlatformUnsupported] is returned.
     8  //
     9  // FreeBSD code is completely untested. Use at your own risk. Feedback is welcome.
    10  package tfo
    11  
    12  import (
    13  	"context"
    14  	"errors"
    15  	"net"
    16  	"os"
    17  	"sync/atomic"
    18  	"syscall"
    19  	"time"
    20  )
    21  
    22  var (
    23  	ErrPlatformUnsupported PlatformUnsupportedError
    24  	errMissingAddress      = errors.New("missing address")
    25  )
    26  
    27  // PlatformUnsupportedError is returned when tfo-go does not support TCP Fast Open on the current platform.
    28  type PlatformUnsupportedError struct{}
    29  
    30  func (PlatformUnsupportedError) Error() string {
    31  	return "tfo-go does not support TCP Fast Open on this platform"
    32  }
    33  
    34  func (PlatformUnsupportedError) Is(target error) bool {
    35  	return target == ErrUnsupported
    36  }
    37  
    38  var runtimeListenNoTFO atomic.Bool
    39  
    40  // ListenConfig wraps [net.ListenConfig] with TFO-related options.
    41  type ListenConfig struct {
    42  	net.ListenConfig
    43  
    44  	// Backlog specifies the maximum number of pending TFO connections on supported platforms.
    45  	// If the value is 0, Go std's listen(2) backlog (4096, as of the current version) is used.
    46  	// If the value is negative, TFO is disabled.
    47  	Backlog int
    48  
    49  	// DisableTFO controls whether TCP Fast Open is disabled when the Listen method is called.
    50  	// TFO is enabled by default, unless [ListenConfig.Backlog] is negative.
    51  	// Set to true to disable TFO and it will behave exactly the same as [net.ListenConfig].
    52  	DisableTFO bool
    53  
    54  	// Fallback controls whether to proceed without TFO when TFO is enabled but not supported
    55  	// on the system.
    56  	Fallback bool
    57  }
    58  
    59  func (lc *ListenConfig) tfoDisabled() bool {
    60  	return lc.Backlog < 0 || lc.DisableTFO
    61  }
    62  
    63  func (lc *ListenConfig) tfoNeedsFallback() bool {
    64  	return lc.Fallback && (comptimeNoTFO || runtimeListenNoTFO.Load())
    65  }
    66  
    67  // Listen is like [net.ListenConfig.Listen] but enables TFO whenever possible,
    68  // unless [ListenConfig.Backlog] is negative or [ListenConfig.DisableTFO] is set to true.
    69  func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) {
    70  	if lc.tfoDisabled() || !networkIsTCP(network) || lc.tfoNeedsFallback() {
    71  		return lc.ListenConfig.Listen(ctx, network, address)
    72  	}
    73  	return lc.listenTFO(ctx, network, address) // tfo_darwin.go, tfo_listen_generic.go, tfo_unsupported.go
    74  }
    75  
    76  // ListenContext is like [net.ListenContext] but enables TFO whenever possible.
    77  func ListenContext(ctx context.Context, network, address string) (net.Listener, error) {
    78  	var lc ListenConfig
    79  	return lc.Listen(ctx, network, address)
    80  }
    81  
    82  // Listen is like [net.Listen] but enables TFO whenever possible.
    83  func Listen(network, address string) (net.Listener, error) {
    84  	return ListenContext(context.Background(), network, address)
    85  }
    86  
    87  // ListenTCP is like [net.ListenTCP] but enables TFO whenever possible.
    88  func ListenTCP(network string, laddr *net.TCPAddr) (*net.TCPListener, error) {
    89  	if !networkIsTCP(network) {
    90  		return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: opAddr(laddr), Err: net.UnknownNetworkError(network)}
    91  	}
    92  	var address string
    93  	if laddr != nil {
    94  		address = laddr.String()
    95  	}
    96  	var lc ListenConfig
    97  	ln, err := lc.listenTFO(context.Background(), network, address) // tfo_darwin.go, tfo_listen_generic.go, tfo_unsupported.go
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return ln.(*net.TCPListener), err
   102  }
   103  
   104  type dialTFOSupport uint32
   105  
   106  const (
   107  	dialTFOSupportDefault dialTFOSupport = iota
   108  	dialTFOSupportNone
   109  	dialTFOSupportLinuxSendto
   110  )
   111  
   112  type atomicDialTFOSupport struct {
   113  	v atomic.Uint32
   114  }
   115  
   116  func (a *atomicDialTFOSupport) load() dialTFOSupport {
   117  	return dialTFOSupport(a.v.Load())
   118  }
   119  
   120  func (a *atomicDialTFOSupport) storeNone() {
   121  	a.v.Store(uint32(dialTFOSupportNone))
   122  }
   123  
   124  var runtimeDialTFOSupport atomicDialTFOSupport
   125  
   126  // Dialer wraps [net.Dialer] with an additional option that allows you to disable TFO.
   127  type Dialer struct {
   128  	net.Dialer
   129  
   130  	// DisableTFO controls whether TCP Fast Open is disabled when the dial methods are called.
   131  	// TFO is enabled by default.
   132  	// Set to true to disable TFO and it will behave exactly the same as [net.Dialer].
   133  	DisableTFO bool
   134  
   135  	// Fallback controls whether to proceed without TFO when TFO is enabled but not supported
   136  	// on the system.
   137  	// On Linux this also controls whether the sendto(MSG_FASTOPEN) fallback path is tried
   138  	// before giving up on TFO.
   139  	Fallback bool
   140  }
   141  
   142  func (d *Dialer) dialAndWrite(ctx context.Context, network, address string, b []byte) (net.Conn, error) {
   143  	c, err := d.Dialer.DialContext(ctx, network, address)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	if err = netConnWriteBytes(ctx, c, b); err != nil {
   148  		c.Close()
   149  		return nil, err
   150  	}
   151  	return c, nil
   152  }
   153  
   154  func (d *Dialer) dialAndWriteTCPConn(ctx context.Context, network, address string, b []byte) (*net.TCPConn, error) {
   155  	c, err := d.Dialer.DialContext(ctx, network, address)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	if err = netConnWriteBytes(ctx, c, b); err != nil {
   160  		c.Close()
   161  		return nil, err
   162  	}
   163  	return c.(*net.TCPConn), nil
   164  }
   165  
   166  // DialContext is like [net.Dialer.DialContext] but enables TFO whenever possible,
   167  // unless [Dialer.DisableTFO] is set to true.
   168  func (d *Dialer) DialContext(ctx context.Context, network, address string, b []byte) (net.Conn, error) {
   169  	if len(b) == 0 {
   170  		return d.Dialer.DialContext(ctx, network, address)
   171  	}
   172  	if d.DisableTFO || !networkIsTCP(network) {
   173  		return d.dialAndWrite(ctx, network, address, b)
   174  	}
   175  	return d.dialTFO(ctx, network, address, b) // tfo_bsd+windows.go, tfo_linux.go, tfo_unsupported.go
   176  }
   177  
   178  // Dial is like [net.Dialer.Dial] but enables TFO whenever possible,
   179  // unless [Dialer.DisableTFO] is set to true.
   180  func (d *Dialer) Dial(network, address string, b []byte) (net.Conn, error) {
   181  	return d.DialContext(context.Background(), network, address, b)
   182  }
   183  
   184  // Dial is like [net.Dial] but enables TFO whenever possible.
   185  func Dial(network, address string, b []byte) (net.Conn, error) {
   186  	var d Dialer
   187  	return d.DialContext(context.Background(), network, address, b)
   188  }
   189  
   190  // DialTimeout is like [net.DialTimeout] but enables TFO whenever possible.
   191  func DialTimeout(network, address string, timeout time.Duration, b []byte) (net.Conn, error) {
   192  	var d Dialer
   193  	d.Timeout = timeout
   194  	return d.DialContext(context.Background(), network, address, b)
   195  }
   196  
   197  // DialTCP is like [net.DialTCP] but enables TFO whenever possible.
   198  func DialTCP(network string, laddr, raddr *net.TCPAddr, b []byte) (*net.TCPConn, error) {
   199  	if len(b) == 0 {
   200  		return net.DialTCP(network, laddr, raddr)
   201  	}
   202  	if !networkIsTCP(network) {
   203  		return nil, &net.OpError{Op: "dial", Net: network, Source: opAddr(laddr), Addr: opAddr(raddr), Err: net.UnknownNetworkError(network)}
   204  	}
   205  	if raddr == nil {
   206  		return nil, &net.OpError{Op: "dial", Net: network, Source: opAddr(laddr), Addr: nil, Err: errMissingAddress}
   207  	}
   208  	return dialTCPAddr(network, laddr, raddr, b) // tfo_bsd+windows.go, tfo_linux.go, tfo_unsupported.go
   209  }
   210  
   211  func networkIsTCP(network string) bool {
   212  	switch network {
   213  	case "tcp", "tcp4", "tcp6":
   214  		return true
   215  	default:
   216  		return false
   217  	}
   218  }
   219  
   220  func opAddr(a *net.TCPAddr) net.Addr {
   221  	if a == nil {
   222  		return nil
   223  	}
   224  	return a
   225  }
   226  
   227  // wrapSyscallError takes an error and a syscall name. If the error is
   228  // a syscall.Errno, it wraps it in a os.SyscallError using the syscall name.
   229  func wrapSyscallError(name string, err error) error {
   230  	if _, ok := err.(syscall.Errno); ok {
   231  		err = os.NewSyscallError(name, err)
   232  	}
   233  	return err
   234  }
   235  
   236  // aLongTimeAgo is a non-zero time, far in the past, used for immediate deadlines.
   237  var aLongTimeAgo = time.Unix(0, 0)
   238  
   239  // writeDeadliner allows cancellation of ongoing write operations.
   240  type writeDeadliner interface {
   241  	SetWriteDeadline(t time.Time) error
   242  }
   243  
   244  // connWriteFunc invokes the given function on a [writeDeadliner] to execute any arbitrary write operation.
   245  // If the given context can be canceled, it will spin up an interruptor goroutine to cancel the write operation
   246  // when the context is canceled.
   247  func connWriteFunc[C writeDeadliner](ctx context.Context, c C, fn func(C) error) (err error) {
   248  	if ctxDone := ctx.Done(); ctxDone != nil {
   249  		done := make(chan struct{})
   250  		interruptRes := make(chan error)
   251  
   252  		defer func() {
   253  			close(done)
   254  			if ctxErr := <-interruptRes; ctxErr != nil && err == nil {
   255  				err = ctxErr
   256  			}
   257  		}()
   258  
   259  		go func() {
   260  			select {
   261  			case <-ctxDone:
   262  				c.SetWriteDeadline(aLongTimeAgo)
   263  				interruptRes <- ctx.Err()
   264  			case <-done:
   265  				interruptRes <- nil
   266  			}
   267  		}()
   268  	}
   269  
   270  	return fn(c)
   271  }
   272  
   273  // netConnWriteBytes is a convenience wrapper around [connWriteFunc] for writing bytes to a [net.Conn].
   274  func netConnWriteBytes(ctx context.Context, c net.Conn, b []byte) error {
   275  	return connWriteFunc(ctx, c, func(c net.Conn) error {
   276  		_, err := c.Write(b)
   277  		return err
   278  	})
   279  }