github.com/database64128/tfo-go/v2@v2.2.0/tfo_windows.go (about) 1 package tfo 2 3 import ( 4 "context" 5 "errors" 6 "net" 7 "os" 8 "runtime" 9 "syscall" 10 "unsafe" 11 12 "golang.org/x/sys/windows" 13 ) 14 15 func setIPv6Only(fd windows.Handle, family int, ipv6only bool) error { 16 if family == windows.AF_INET6 { 17 // Allow both IP versions even if the OS default 18 // is otherwise. Note that some operating systems 19 // never admit this option. 20 return windows.SetsockoptInt(fd, windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, boolint(ipv6only)) 21 } 22 return nil 23 } 24 25 func setNoDelay(fd windows.Handle, noDelay int) error { 26 return windows.SetsockoptInt(fd, windows.IPPROTO_TCP, windows.TCP_NODELAY, noDelay) 27 } 28 29 func setUpdateConnectContext(fd windows.Handle) error { 30 return windows.Setsockopt(fd, windows.SOL_SOCKET, windows.SO_UPDATE_CONNECT_CONTEXT, nil, 0) 31 } 32 33 func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *net.TCPAddr, b []byte, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*net.TCPConn, error) { 34 ltsa := (*tcpSockaddr)(laddr) 35 rtsa := (*tcpSockaddr)(raddr) 36 family, ipv6only := favoriteAddrFamily(network, ltsa, rtsa, "dial") 37 38 var ( 39 ip net.IP 40 port int 41 zone string 42 ) 43 44 if laddr != nil { 45 ip = laddr.IP 46 port = laddr.Port 47 zone = laddr.Zone 48 } 49 50 lsa, err := ipToSockaddr(family, ip, port, zone) 51 if err != nil { 52 return nil, err 53 } 54 55 rsa, err := rtsa.sockaddr(family) 56 if err != nil { 57 return nil, err 58 } 59 60 handle, err := windows.WSASocket(int32(family), windows.SOCK_STREAM, windows.IPPROTO_TCP, nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT) 61 if err != nil { 62 return nil, os.NewSyscallError("WSASocket", err) 63 } 64 65 fd, err := newFD(syscall.Handle(handle), family, windows.SOCK_STREAM, network) 66 if err != nil { 67 windows.Closesocket(handle) 68 return nil, err 69 } 70 71 if err = setIPv6Only(handle, family, ipv6only); err != nil { 72 fd.Close() 73 return nil, wrapSyscallError("setsockopt(IPV6_V6ONLY)", err) 74 } 75 76 if err = setNoDelay(handle, 1); err != nil { 77 fd.Close() 78 return nil, wrapSyscallError("setsockopt(TCP_NODELAY)", err) 79 } 80 81 if err = setTFODialer(uintptr(handle)); err != nil { 82 if !d.Fallback || !errors.Is(err, errors.ErrUnsupported) { 83 fd.Close() 84 return nil, wrapSyscallError("setsockopt(TCP_FASTOPEN)", err) 85 } 86 runtimeDialTFOSupport.storeNone() 87 } 88 89 if ctrlCtxFn != nil { 90 if err = ctrlCtxFn(ctx, fd.ctrlNetwork(), raddr.String(), newRawConn(fd)); err != nil { 91 fd.Close() 92 return nil, err 93 } 94 } 95 96 if err = syscall.Bind(syscall.Handle(handle), lsa); err != nil { 97 fd.Close() 98 return nil, wrapSyscallError("bind", err) 99 } 100 101 if err = fd.init(); err != nil { 102 fd.Close() 103 return nil, err 104 } 105 106 if err = connWriteFunc(ctx, fd, func(fd *netFD) error { 107 n, err := fd.pfd.ConnectEx(rsa, b) 108 if err != nil { 109 return os.NewSyscallError("connectex", err) 110 } 111 112 if err = setUpdateConnectContext(handle); err != nil { 113 return wrapSyscallError("setsockopt(SO_UPDATE_CONNECT_CONTEXT)", err) 114 } 115 116 lsa, err = syscall.Getsockname(syscall.Handle(handle)) 117 if err != nil { 118 return wrapSyscallError("getsockname", err) 119 } 120 fd.laddr = sockaddrToTCP(lsa) 121 122 rsa, err = syscall.Getpeername(syscall.Handle(handle)) 123 if err != nil { 124 return wrapSyscallError("getpeername", err) 125 } 126 fd.raddr = sockaddrToTCP(rsa) 127 128 if n < len(b) { 129 if _, err = fd.Write(b[n:]); err != nil { 130 return err 131 } 132 } 133 134 return nil 135 }); err != nil { 136 fd.Close() 137 return nil, err 138 } 139 140 runtime.SetFinalizer(fd, netFDClose) 141 return (*net.TCPConn)(unsafe.Pointer(&fd)), nil 142 }