github.com/sagernet/tfo-go@v0.0.0-20231209031829-7b5343ac1dc6/tfo_bsd+linux.go (about)

     1  //go:build darwin || freebsd || linux
     2  
     3  package tfo
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"net"
     9  	"os"
    10  	"syscall"
    11  
    12  	"golang.org/x/sys/unix"
    13  )
    14  
    15  func setIPv6Only(fd int, family int, ipv6only bool) error {
    16  	if family == unix.AF_INET6 {
    17  		// Allow both IP versions even if the OS default
    18  		// is otherwise. Note that some operating systems
    19  		// never admit this option.
    20  		return unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, boolint(ipv6only))
    21  	}
    22  	return nil
    23  }
    24  
    25  func setNoDelay(fd int, noDelay int) error {
    26  	return unix.SetsockoptInt(fd, unix.IPPROTO_TCP, unix.TCP_NODELAY, noDelay)
    27  }
    28  
    29  func ctrlNetwork(network string, family int) string {
    30  	if network == "tcp4" || family == unix.AF_INET {
    31  		return "tcp4"
    32  	}
    33  	return "tcp6"
    34  }
    35  
// dialSingle dials raddr on one socket. It uses the address family chosen by
// favoriteAddrFamily and applies the TFO socket option. b is handed to
// connect so it can be sent with the connection attempt. Any bytes of b that
// connect did not report as sent are written afterwards. On success the
// socket is returned as a *net.TCPConn.
//
// laddr, when non-nil, is bound before connecting. ctrlCtxFn, when non-nil,
// is called on the raw socket before bind and connect, with the network
// name from ctrlNetwork and raddr in string form.
//
// When d.Fallback is set, two failures degrade instead of aborting, and both
// record the missing support via runtimeDialTFOSupport.storeNone():
//   - a setTFODialerFromSocket error matching os.ErrInvalid: the dial
//     continues on the same socket;
//   - a connect failure flagged canFallback: the dial is redone through
//     d.dialAndWriteTCPConn.
func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *net.TCPAddr, b []byte, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*net.TCPConn, error) {
	ltsa := (*tcpSockaddr)(laddr)
	rtsa := (*tcpSockaddr)(raddr)
	family, ipv6only := favoriteAddrFamily(network, ltsa, rtsa, "dial")

	lsa, err := ltsa.sockaddr(family)
	if err != nil {
		return nil, err
	}

	rsa, err := rtsa.sockaddr(family)
	if err != nil {
		return nil, err
	}

	fd, err := d.socket(family)
	if err != nil {
		return nil, wrapSyscallError("socket", err)
	}

	// Until fd is wrapped in an *os.File below, every error path must
	// close it by hand.
	// Note: this calls the Dialer method d.setIPv6Only, not the
	// package-level setIPv6Only function.
	if err = d.setIPv6Only(fd, family, ipv6only); err != nil {
		unix.Close(fd)
		return nil, wrapSyscallError("setsockopt(IPV6_V6ONLY)", err)
	}

	if err = setNoDelay(fd, 1); err != nil {
		unix.Close(fd)
		return nil, wrapSyscallError("setsockopt(TCP_NODELAY)", err)
	}

	if err = setTFODialerFromSocket(uintptr(fd)); err != nil {
		// Abort unless fallback is enabled and the error is os.ErrInvalid.
		// In that case, remember that TFO is unavailable and keep using
		// this socket.
		if !d.Fallback || !errors.Is(err, os.ErrInvalid) {
			unix.Close(fd)
			return nil, wrapSyscallError("setsockopt("+setTFODialerFromSocketSockoptName+")", err)
		}
		runtimeDialTFOSupport.storeNone()
	}

	// From here on f owns fd. The deferred Close releases it on every
	// return path. net.FileConn below works on a duplicate descriptor, so
	// closing f does not affect the returned connection.
	f := os.NewFile(uintptr(fd), "")
	defer f.Close()

	rawConn, err := f.SyscallConn()
	if err != nil {
		return nil, err
	}

	if ctrlCtxFn != nil {
		if err = ctrlCtxFn(ctx, ctrlNetwork(network, family), raddr.String(), rawConn); err != nil {
			return nil, err
		}
	}

	if laddr != nil {
		// The Control error (cErr) and the bind error (err) are reported
		// separately; only the latter is wrapped as a syscall error.
		if cErr := rawConn.Control(func(fd uintptr) {
			err = syscall.Bind(int(fd), lsa)
		}); cErr != nil {
			return nil, cErr
		}
		if err != nil {
			return nil, wrapSyscallError("bind", err)
		}
	}

	// n and canFallback are filled in by the connect callback below.
	var (
		n           int
		canFallback bool
	)

	// NOTE(review): connWriteFunc is defined elsewhere. It presumably
	// applies ctx (deadline/cancellation) to f around the callback;
	// confirm.
	if err = connWriteFunc(ctx, f, func(f *os.File) (err error) {
		n, canFallback, err = connect(rawConn, rsa, b)
		return err
	}); err != nil {
		if d.Fallback && canFallback {
			runtimeDialTFOSupport.storeNone()
			// NOTE(review): only network, raddr and b are forwarded here.
			// Confirm that dialAndWriteTCPConn honors laddr and the control
			// function.
			return d.dialAndWriteTCPConn(ctx, network, raddr.String(), b)
		}
		return nil, err
	}

	c, err := net.FileConn(f)
	if err != nil {
		return nil, err
	}

	// connect may have sent only part of b, or none of it. Write the rest
	// over the new connection.
	if n < len(b) {
		if err = netConnWriteBytes(ctx, c, b[n:]); err != nil {
			c.Close()
			return nil, err
		}
	}

	// err is always nil here, because every failure above returns early.
	// The unchecked assertion presumably holds because fd is a TCP socket
	// of an IP family; it would panic otherwise.
	return c.(*net.TCPConn), err
}
   129  
   130  func connect(rawConn syscall.RawConn, rsa syscall.Sockaddr, b []byte) (n int, canFallback bool, err error) {
   131  	var done bool
   132  
   133  	if perr := rawConn.Write(func(fd uintptr) bool {
   134  		if done {
   135  			return true
   136  		}
   137  
   138  		n, err = doConnect(fd, rsa, b)
   139  		if err == unix.EINPROGRESS {
   140  			done = true
   141  			err = nil
   142  			return false
   143  		}
   144  		return true
   145  	}); perr != nil {
   146  		return 0, false, perr
   147  	}
   148  
   149  	if err != nil {
   150  		return 0, doConnectCanFallback(err), wrapSyscallError(connectSyscallName, err)
   151  	}
   152  
   153  	if perr := rawConn.Control(func(fd uintptr) {
   154  		err = getSocketError(int(fd), connectSyscallName)
   155  	}); perr != nil {
   156  		return 0, false, perr
   157  	}
   158  
   159  	return
   160  }
   161  
   162  func getSocketError(fd int, call string) error {
   163  	nerr, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
   164  	if err != nil {
   165  		return wrapSyscallError("getsockopt", err)
   166  	}
   167  	if nerr != 0 {
   168  		return os.NewSyscallError(call, syscall.Errno(nerr))
   169  	}
   170  	return nil
   171  }