github.com/database64128/tfo-go/v2@v2.2.0/tfo_linux.go (about)

     1  package tfo
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"syscall"
     8  
     9  	"golang.org/x/sys/unix"
    10  )
    11  
// setTFODialerFromSocketSockoptName is the socket option name paired with
// setTFODialerFromSocket.
// NOTE(review): the "unreachable" value suggests this name is never expected to
// be reported in this file, since setTFODialerFromSocket here always returns
// nil — confirm against the callers.
const setTFODialerFromSocketSockoptName = "unreachable"
    13  
    14  func setTFODialerFromSocket(fd uintptr) error {
    15  	return nil
    16  }
    17  
// sendtoImplicitConnectFlag is MSG_FASTOPEN, the sendto() flag whose kernel
// support doConnectCanFallback checks for.
// NOTE(review): presumably passed to sendto() by doConnect so that the call
// connects and sends data in one step — confirm against the caller.
const sendtoImplicitConnectFlag = unix.MSG_FASTOPEN
    19  
    20  // doConnectCanFallback returns whether err from [doConnect] indicates lack of TFO support.
    21  func doConnectCanFallback(err error) bool {
    22  	// On Linux, calling sendto() on an unconnected TCP socket with zero or invalid flags
    23  	// returns -EPIPE. This indicates that the MSG_FASTOPEN flag is not recognized by the kernel.
    24  	//
    25  	// -EOPNOTSUPP is returned if the kernel recognizes the flag, but TFO is disabled via sysctl.
    26  	return err == syscall.EPIPE || err == syscall.EOPNOTSUPP
    27  }
    28  
    29  func (a *atomicDialTFOSupport) casLinuxSendto() bool {
    30  	return a.v.CompareAndSwap(uint32(dialTFOSupportDefault), uint32(dialTFOSupportLinuxSendto))
    31  }
    32  
    33  func (d *Dialer) dialTFO(ctx context.Context, network, address string, b []byte) (*net.TCPConn, error) {
    34  	if d.Fallback {
    35  		switch runtimeDialTFOSupport.load() {
    36  		case dialTFOSupportNone:
    37  			return d.dialAndWriteTCPConn(ctx, network, address, b)
    38  		case dialTFOSupportLinuxSendto:
    39  			return d.dialTFOFromSocket(ctx, network, address, b)
    40  		}
    41  	}
    42  
    43  	var canFallback bool
    44  	ctrlCtxFn := d.ControlContext
    45  	ctrlFn := d.Control
    46  	ld := *d
    47  	ld.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
    48  		switch {
    49  		case ctrlCtxFn != nil:
    50  			if err = ctrlCtxFn(ctx, network, address, c); err != nil {
    51  				return err
    52  			}
    53  		case ctrlFn != nil:
    54  			if err = ctrlFn(network, address, c); err != nil {
    55  				return err
    56  			}
    57  		}
    58  
    59  		if cerr := c.Control(func(fd uintptr) {
    60  			err = setTFODialer(fd)
    61  		}); cerr != nil {
    62  			return cerr
    63  		}
    64  
    65  		if err != nil {
    66  			if d.Fallback && errors.Is(err, errors.ErrUnsupported) {
    67  				canFallback = true
    68  			}
    69  			return wrapSyscallError("setsockopt(TCP_FASTOPEN_CONNECT)", err)
    70  		}
    71  		return nil
    72  	}
    73  
    74  	nc, err := ld.Dialer.DialContext(ctx, network, address)
    75  	if err != nil {
    76  		if d.Fallback && canFallback {
    77  			runtimeDialTFOSupport.casLinuxSendto()
    78  			return d.dialTFOFromSocket(ctx, network, address, b)
    79  		}
    80  		return nil, err
    81  	}
    82  	if err = netConnWriteBytes(ctx, nc, b); err != nil {
    83  		nc.Close()
    84  		return nil, err
    85  	}
    86  	return nc.(*net.TCPConn), nil
    87  }
    88  
    89  func dialTCPAddr(network string, laddr, raddr *net.TCPAddr, b []byte) (*net.TCPConn, error) {
    90  	d := Dialer{Dialer: net.Dialer{LocalAddr: laddr}}
    91  	return d.dialTFO(context.Background(), network, raddr.String(), b)
    92  }