github.com/sagernet/tfo-go@v0.0.0-20231209031829-7b5343ac1dc6/tfo_windows.go (about)

     1  //lint:file-ignore U1000 linkname magic brings a lot of unused unexported fields.
     2  
     3  package tfo
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"net"
     9  	"os"
    10  	"sync"
    11  	"syscall"
    12  	"unsafe"
    13  
    14  	"golang.org/x/sys/windows"
    15  )
    16  
    17  func setIPv6Only(fd windows.Handle, family int, ipv6only bool) error {
    18  	if family == windows.AF_INET6 {
    19  		// Allow both IP versions even if the OS default
    20  		// is otherwise. Note that some operating systems
    21  		// never admit this option.
    22  		return windows.SetsockoptInt(fd, windows.IPPROTO_IPV6, windows.IPV6_V6ONLY, boolint(ipv6only))
    23  	}
    24  	return nil
    25  }
    26  
    27  func setNoDelay(fd windows.Handle, noDelay int) error {
    28  	return windows.SetsockoptInt(fd, windows.IPPROTO_TCP, windows.TCP_NODELAY, noDelay)
    29  }
    30  
    31  func setUpdateConnectContext(fd windows.Handle) error {
    32  	return windows.Setsockopt(fd, windows.SOL_SOCKET, windows.SO_UPDATE_CONNECT_CONTEXT, nil, 0)
    33  }
    34  
// sockaddrToTCP is the net package's unexported sockaddrToTCP, pulled in via
// linkname. It converts a syscall.Sockaddr into a net.Addr (presumably a
// *net.TCPAddr, per its name). It is used to fill netFD.laddr/raddr after
// ConnectEx.
//
//go:linkname sockaddrToTCP net.sockaddrToTCP
func sockaddrToTCP(sa syscall.Sockaddr) net.Addr
    37  
// runtime_pollServerInit is internal/poll's hook for initializing the runtime
// network poller, pulled in via linkname. pFD.init calls it at most once
// through serverInit.
//
//go:linkname runtime_pollServerInit internal/poll.runtime_pollServerInit
func runtime_pollServerInit()
    40  
// runtime_pollOpen is internal/poll's hook for registering a handle with the
// runtime network poller, pulled in via linkname. It returns the poller
// context for fd and an errno. pFD.init treats a non-zero errno as failure.
//
//go:linkname runtime_pollOpen internal/poll.runtime_pollOpen
func runtime_pollOpen(fd uintptr) (uintptr, int)
    43  
// serverInit ensures runtime_pollServerInit runs only once (see pFD.init).
//
// Copied from src/internal/poll/fd_poll_runtime.go
var serverInit sync.Once
    46  
// operation contains a superset of the data necessary to perform all async IO.
//
// execIO is linknamed to internal/poll.execIO and receives *operation, so this
// struct must match internal/poll's operation field-for-field for the Go
// toolchain in use. Do not reorder or remove fields, even unused ones.
//
// Copied from src/internal/poll/fd_windows.go
type operation struct {
	// Used by IOCP interface, it must be the first field
	// of the struct, as our code relies on it.
	o syscall.Overlapped

	// fields used by runtime.netpoll
	runtimeCtx uintptr
	mode       int32
	errno      int32
	qty        uint32

	// fields used only by net package
	fd  *pFD
	buf syscall.WSABuf
	// NOTE(review): assumes x/sys/windows.WSAMsg has the same layout as the
	// WSAMsg type used by internal/poll — confirm when upgrading Go.
	msg    windows.WSAMsg
	sa     syscall.Sockaddr
	rsa    *syscall.RawSockaddrAny
	rsan   int32
	handle syscall.Handle
	flags  uint32
	bufs   []syscall.WSABuf
}
    72  
// execIO is internal/poll's execIO, pulled in via linkname. It runs submit to
// start the overlapped operation o and returns a byte count and an error.
// ConnectEx treats the count as the number of bytes sent. Presumably it
// blocks on the runtime poller until the operation completes — verify
// against internal/poll for the Go version in use.
//
//go:linkname execIO internal/poll.execIO
func execIO(o *operation, submit func(o *operation) error) (int, error)
    75  
// pFD is a file descriptor. The net and os packages embed this type in
// a larger type representing a network connection or OS file.
//
// pFD mirrors internal/poll.FD. netFD embeds it as its first field, and the
// value returned by newFD (linknamed to net.newFD) is read through it. Field
// order, types and sizes must therefore match poll.FD for the Go toolchain in
// use. Do not reorder or remove fields, even unused ones.
//
// Copied from src/internal/poll/fd_windows.go
type pFD struct {
	// fdmuS, fdmuR and fdmuW occupy the space of poll.FD's fdMutex
	// (presumably its state and read/write semaphores — verify against the
	// Go version in use).
	fdmuS uint64
	fdmuR uint32
	fdmuW uint32

	// System file descriptor. Immutable until Close.
	Sysfd syscall.Handle

	// Read operation.
	rop operation
	// Write operation.
	wop operation

	// I/O poller: the context returned by runtime_pollOpen, set by init.
	pd uintptr

	// Used to implement pread/pwrite.
	l sync.Mutex

	// For console I/O.
	lastbits       []byte   // first few bytes of the last incomplete rune in last write
	readuint16     []uint16 // buffer to hold uint16s obtained with ReadConsole
	readbyte       []byte   // buffer to hold decoding of readuint16 from utf16 to utf8
	readbyteOffset int      // readbyte[readbyteOffset:] is yet to be consumed with file.Read

	// Semaphore signaled when file is closed.
	csema uint32

	skipSyncNotif bool

	// Whether this is a streaming descriptor, as opposed to a
	// packet-based descriptor like a UDP socket.
	IsStream bool

	// Whether a zero byte read indicates EOF. This is false for a
	// message based socket connection.
	ZeroReadIsEOF bool

	// Whether this is a file rather than a network socket.
	isFile bool

	// The kind of this file.
	kind byte
}
   124  
   125  func (fd *pFD) init() error {
   126  	serverInit.Do(runtime_pollServerInit)
   127  	ctx, errno := runtime_pollOpen(uintptr(fd.Sysfd))
   128  	if errno != 0 {
   129  		return syscall.Errno(errno)
   130  	}
   131  	fd.pd = ctx
   132  	fd.rop.mode = 'r'
   133  	fd.wop.mode = 'w'
   134  	fd.rop.fd = fd
   135  	fd.wop.fd = fd
   136  	fd.rop.runtimeCtx = fd.pd
   137  	fd.wop.runtimeCtx = fd.pd
   138  	return nil
   139  }
   140  
   141  func (fd *pFD) ConnectEx(ra syscall.Sockaddr, b []byte) (n int, err error) {
   142  	fd.wop.sa = ra
   143  	n, err = execIO(&fd.wop, func(o *operation) error {
   144  		return syscall.ConnectEx(o.fd.Sysfd, o.sa, &b[0], uint32(len(b)), &o.qty, &o.o)
   145  	})
   146  	return
   147  }
   148  
// netFD is a network file descriptor.
//
// netFD mirrors net.netFD. newFD (linknamed to net.newFD) returns a value
// allocated by the net package that is accessed through this type, and
// dialSingle casts the address of a *netFD variable to *net.TCPConn. The
// field layout must therefore match net.netFD for the Go toolchain in use.
//
// Copied from src/net/fd_posix.go
type netFD struct {
	pfd pFD

	// immutable until Close
	family      int
	sotype      int
	isConnected bool // handshake completed or use of association with peer
	net         string
	laddr       net.Addr
	raddr       net.Addr
}
   163  
   164  func (fd *netFD) ctrlNetwork() string {
   165  	if fd.net == "tcp4" || fd.family == windows.AF_INET {
   166  		return "tcp4"
   167  	}
   168  	return "tcp6"
   169  }
   170  
// newFD is the net package's unexported newFD, pulled in via linkname. It
// wraps sysfd in a net-owned netFD for the given family, socket type and
// network name. The result is accessed through the local netFD mirror, so the
// two layouts must agree.
//
//go:linkname newFD net.newFD
func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error)
   173  
// rawConn is a minimal syscall.RawConn view of a netFD. dialSingle passes it
// to the caller's control function. Only Control is fully supported; Read and
// Write invoke the callback once and then report syscall.EWINDOWS.
type rawConn netFD
   175  
   176  func (c *rawConn) Control(f func(uintptr)) error {
   177  	f(uintptr(c.pfd.Sysfd))
   178  	return nil
   179  }
   180  
   181  func (c *rawConn) Read(f func(uintptr) bool) error {
   182  	f(uintptr(c.pfd.Sysfd))
   183  	return syscall.EWINDOWS
   184  }
   185  
   186  func (c *rawConn) Write(f func(uintptr) bool) error {
   187  	f(uintptr(c.pfd.Sysfd))
   188  	return syscall.EWINDOWS
   189  }
   190  
   191  func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *net.TCPAddr, b []byte, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*net.TCPConn, error) {
   192  	ltsa := (*tcpSockaddr)(laddr)
   193  	rtsa := (*tcpSockaddr)(raddr)
   194  	family, ipv6only := favoriteAddrFamily(network, ltsa, rtsa, "dial")
   195  
   196  	var (
   197  		ip   net.IP
   198  		port int
   199  		zone string
   200  	)
   201  
   202  	if laddr != nil {
   203  		ip = laddr.IP
   204  		port = laddr.Port
   205  		zone = laddr.Zone
   206  	}
   207  
   208  	lsa, err := ipToSockaddr(family, ip, port, zone)
   209  	if err != nil {
   210  		return nil, err
   211  	}
   212  
   213  	rsa, err := rtsa.sockaddr(family)
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  
   218  	handle, err := windows.WSASocket(int32(family), windows.SOCK_STREAM, windows.IPPROTO_TCP, nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
   219  	if err != nil {
   220  		return nil, os.NewSyscallError("WSASocket", err)
   221  	}
   222  
   223  	fd, err := newFD(syscall.Handle(handle), family, windows.SOCK_STREAM, network)
   224  	if err != nil {
   225  		windows.Closesocket(handle)
   226  		return nil, err
   227  	}
   228  
   229  	tc := (*net.TCPConn)(unsafe.Pointer(&fd))
   230  
   231  	if err = setIPv6Only(handle, family, ipv6only); err != nil {
   232  		tc.Close()
   233  		return nil, wrapSyscallError("setsockopt(IPV6_V6ONLY)", err)
   234  	}
   235  
   236  	if err = setNoDelay(handle, 1); err != nil {
   237  		tc.Close()
   238  		return nil, wrapSyscallError("setsockopt(TCP_NODELAY)", err)
   239  	}
   240  
   241  	if err = setTFODialer(uintptr(handle)); err != nil {
   242  		if !d.Fallback || !errors.Is(err, os.ErrInvalid) {
   243  			tc.Close()
   244  			return nil, wrapSyscallError("setsockopt(TCP_FASTOPEN)", err)
   245  		}
   246  		runtimeDialTFOSupport.storeNone()
   247  	}
   248  
   249  	if ctrlCtxFn != nil {
   250  		if err = ctrlCtxFn(ctx, fd.ctrlNetwork(), raddr.String(), (*rawConn)(fd)); err != nil {
   251  			tc.Close()
   252  			return nil, err
   253  		}
   254  	}
   255  
   256  	if err = syscall.Bind(syscall.Handle(handle), lsa); err != nil {
   257  		tc.Close()
   258  		return nil, wrapSyscallError("bind", err)
   259  	}
   260  
   261  	if err = fd.pfd.init(); err != nil {
   262  		tc.Close()
   263  		return nil, err
   264  	}
   265  
   266  	if err = connWriteFunc(ctx, tc, func(c *net.TCPConn) error {
   267  		n, err := fd.pfd.ConnectEx(rsa, b)
   268  		if err != nil {
   269  			return os.NewSyscallError("connectex", err)
   270  		}
   271  
   272  		if err = setUpdateConnectContext(handle); err != nil {
   273  			return wrapSyscallError("setsockopt(SO_UPDATE_CONNECT_CONTEXT)", err)
   274  		}
   275  
   276  		lsa, err = syscall.Getsockname(syscall.Handle(handle))
   277  		if err != nil {
   278  			return wrapSyscallError("getsockname", err)
   279  		}
   280  		fd.laddr = sockaddrToTCP(lsa)
   281  
   282  		rsa, err = syscall.Getpeername(syscall.Handle(handle))
   283  		if err != nil {
   284  			return wrapSyscallError("getpeername", err)
   285  		}
   286  		fd.raddr = sockaddrToTCP(rsa)
   287  
   288  		if n < len(b) {
   289  			if _, err = tc.Write(b[n:]); err != nil {
   290  				return err
   291  			}
   292  		}
   293  
   294  		return nil
   295  	}); err != nil {
   296  		tc.Close()
   297  		return nil, err
   298  	}
   299  
   300  	return tc, nil
   301  }