github.com/sagernet/tfo-go@v0.0.0-20231209031829-7b5343ac1dc6/tfo_linux.go (about)

     1  package tfo
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"os"
     8  	"syscall"
     9  
    10  	"golang.org/x/sys/unix"
    11  )
    12  
    13  const setTFODialerFromSocketSockoptName = "unreachable"
    14  
    15  func setTFODialerFromSocket(fd uintptr) error {
    16  	return nil
    17  }
    18  
    19  const sendtoImplicitConnectFlag = unix.MSG_FASTOPEN
    20  
    21  // doConnectCanFallback returns whether err from [doConnect] indicates lack of TFO support.
    22  func doConnectCanFallback(err error) bool {
    23  	// On Linux, calling sendto() on an unconnected TCP socket with zero or invalid flags
    24  	// returns -EPIPE. This indicates that the MSG_FASTOPEN flag is not recognized by the kernel.
    25  	//
    26  	// -EOPNOTSUPP is returned if the kernel recognizes the flag, but TFO is disabled via sysctl.
    27  	return err == syscall.EPIPE || err == syscall.EOPNOTSUPP
    28  }
    29  
    30  func (a *atomicDialTFOSupport) casLinuxSendto() bool {
    31  	return a.v.CompareAndSwap(uint32(dialTFOSupportDefault), uint32(dialTFOSupportLinuxSendto))
    32  }
    33  
    34  func (d *Dialer) dialTFO(ctx context.Context, network, address string, b []byte) (*net.TCPConn, error) {
    35  	if d.Fallback {
    36  		switch runtimeDialTFOSupport.load() {
    37  		case dialTFOSupportNone:
    38  			return d.dialAndWriteTCPConn(ctx, network, address, b)
    39  		case dialTFOSupportLinuxSendto:
    40  			return d.dialTFOFromSocket(ctx, network, address, b)
    41  		}
    42  	}
    43  
    44  	var canFallback bool
    45  	ctrlCtxFn := d.ControlContext
    46  	ctrlFn := d.Control
    47  	ld := *d
    48  	ld.ControlContext = func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
    49  		switch {
    50  		case ctrlCtxFn != nil:
    51  			if err = ctrlCtxFn(ctx, network, address, c); err != nil {
    52  				return err
    53  			}
    54  		case ctrlFn != nil:
    55  			if err = ctrlFn(network, address, c); err != nil {
    56  				return err
    57  			}
    58  		}
    59  
    60  		if cerr := c.Control(func(fd uintptr) {
    61  			err = setTFODialer(fd)
    62  		}); cerr != nil {
    63  			return cerr
    64  		}
    65  
    66  		if err != nil {
    67  			if d.Fallback && errors.Is(err, os.ErrInvalid) {
    68  				canFallback = true
    69  			}
    70  			return wrapSyscallError("setsockopt(TCP_FASTOPEN_CONNECT)", err)
    71  		}
    72  		return nil
    73  	}
    74  
    75  	nc, err := ld.Dialer.DialContext(ctx, network, address)
    76  	if err != nil {
    77  		if d.Fallback && canFallback {
    78  			runtimeDialTFOSupport.casLinuxSendto()
    79  			return d.dialTFOFromSocket(ctx, network, address, b)
    80  		}
    81  		return nil, err
    82  	}
    83  	if err = netConnWriteBytes(ctx, nc, b); err != nil {
    84  		nc.Close()
    85  		return nil, err
    86  	}
    87  	return nc.(*net.TCPConn), nil
    88  }
    89  
    90  func dialTCPAddr(network string, laddr, raddr *net.TCPAddr, b []byte) (*net.TCPConn, error) {
    91  	d := Dialer{Dialer: net.Dialer{LocalAddr: laddr}}
    92  	return d.dialTFO(context.Background(), network, raddr.String(), b)
    93  }