github.com/metacubex/tfo-go@v0.0.0-20240228025757-be1269474a66/tfo_windows.go (about)

     1  package tfo
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"os"
     8  	"runtime"
     9  	"syscall"
    10  	"unsafe"
    11  
    12  	"golang.org/x/sys/windows"
    13  )
    14  
    15  func setIPv6Only(fd windows.Handle, family int, ipv6only bool) error {
    16  	if family == windows.AF_INET6 {
    17  		// Allow both IP versions even if the OS default
    18  		// is otherwise. Note that some operating systems
    19  		// never admit this option.
    20  		return windows.SetsockoptInt(fd, windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, boolint(ipv6only))
    21  	}
    22  	return nil
    23  }
    24  
    25  func setNoDelay(fd windows.Handle, noDelay int) error {
    26  	return windows.SetsockoptInt(fd, windows.IPPROTO_TCP, windows.TCP_NODELAY, noDelay)
    27  }
    28  
    29  func setUpdateConnectContext(fd windows.Handle) error {
    30  	return windows.Setsockopt(fd, windows.SOL_SOCKET, windows.SO_UPDATE_CONNECT_CONTEXT, nil, 0)
    31  }
    32  
    33  func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *net.TCPAddr, b []byte, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*net.TCPConn, error) {
    34  	ltsa := (*tcpSockaddr)(laddr)
    35  	rtsa := (*tcpSockaddr)(raddr)
    36  	family, ipv6only := favoriteAddrFamily(network, ltsa, rtsa, "dial")
    37  
    38  	var (
    39  		ip   net.IP
    40  		port int
    41  		zone string
    42  	)
    43  
    44  	if laddr != nil {
    45  		ip = laddr.IP
    46  		port = laddr.Port
    47  		zone = laddr.Zone
    48  	}
    49  
    50  	lsa, err := ipToSockaddr(family, ip, port, zone)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	rsa, err := rtsa.sockaddr(family)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	handle, err := windows.WSASocket(int32(family), windows.SOCK_STREAM, windows.IPPROTO_TCP, nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
    61  	if err != nil {
    62  		return nil, os.NewSyscallError("WSASocket", err)
    63  	}
    64  
    65  	fd, err := newFD(syscall.Handle(handle), family, windows.SOCK_STREAM, network)
    66  	if err != nil {
    67  		windows.Closesocket(handle)
    68  		return nil, err
    69  	}
    70  
    71  	if err = setIPv6Only(handle, family, ipv6only); err != nil {
    72  		fd.Close()
    73  		return nil, wrapSyscallError("setsockopt(IPV6_V6ONLY)", err)
    74  	}
    75  
    76  	if err = setNoDelay(handle, 1); err != nil {
    77  		fd.Close()
    78  		return nil, wrapSyscallError("setsockopt(TCP_NODELAY)", err)
    79  	}
    80  
    81  	if err = setTFODialer(uintptr(handle)); err != nil {
    82  		if !d.Fallback || !errors.Is(err, ErrUnsupported) {
    83  			fd.Close()
    84  			return nil, wrapSyscallError("setsockopt(TCP_FASTOPEN)", err)
    85  		}
    86  		runtimeDialTFOSupport.storeNone()
    87  	}
    88  
    89  	if ctrlCtxFn != nil {
    90  		if err = ctrlCtxFn(ctx, fd.ctrlNetwork(), raddr.String(), newRawConn(fd)); err != nil {
    91  			fd.Close()
    92  			return nil, err
    93  		}
    94  	}
    95  
    96  	if err = syscall.Bind(syscall.Handle(handle), lsa); err != nil {
    97  		fd.Close()
    98  		return nil, wrapSyscallError("bind", err)
    99  	}
   100  
   101  	if err = fd.init(); err != nil {
   102  		fd.Close()
   103  		return nil, err
   104  	}
   105  
   106  	if err = connWriteFunc(ctx, fd, func(fd *netFD) error {
   107  		n, err := fd.pfd.ConnectEx(rsa, b)
   108  		if err != nil {
   109  			return os.NewSyscallError("connectex", err)
   110  		}
   111  
   112  		if err = setUpdateConnectContext(handle); err != nil {
   113  			return wrapSyscallError("setsockopt(SO_UPDATE_CONNECT_CONTEXT)", err)
   114  		}
   115  
   116  		lsa, err = syscall.Getsockname(syscall.Handle(handle))
   117  		if err != nil {
   118  			return wrapSyscallError("getsockname", err)
   119  		}
   120  		fd.laddr = sockaddrToTCP(lsa)
   121  
   122  		rsa, err = syscall.Getpeername(syscall.Handle(handle))
   123  		if err != nil {
   124  			return wrapSyscallError("getpeername", err)
   125  		}
   126  		fd.raddr = sockaddrToTCP(rsa)
   127  
   128  		if n < len(b) {
   129  			if _, err = fd.Write(b[n:]); err != nil {
   130  				return err
   131  			}
   132  		}
   133  
   134  		return nil
   135  	}); err != nil {
   136  		fd.Close()
   137  		return nil, err
   138  	}
   139  
   140  	runtime.SetFinalizer(fd, netFDClose)
   141  	return (*net.TCPConn)(unsafe.Pointer(&fd)), nil
   142  }