github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/copy_direct_windows.go (about)

     1  package bufio
     2  
     3  import (
     4  	"io"
     5  	"net/netip"
     6  	"os"
     7  	"syscall"
     8  	"unsafe"
     9  
    10  	"github.com/sagernet/sing/common/buf"
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  	M "github.com/sagernet/sing/common/metadata"
    13  	N "github.com/sagernet/sing/common/network"
    14  
    15  	"golang.org/x/sys/windows"
    16  )
    17  
    18  var modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
    19  
    20  var procrecv = modws2_32.NewProc("recv")
    21  
    22  // Do the interface allocations only once for common
    23  // Errno values.
    24  const (
    25  	errnoERROR_IO_PENDING = 997
    26  )
    27  
    28  var (
    29  	errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
    30  	errERROR_EINVAL     error = syscall.EINVAL
    31  )
    32  
    33  // errnoErr returns common boxed Errno values, to prevent
    34  // allocations at runtime.
    35  func errnoErr(e syscall.Errno) error {
    36  	switch e {
    37  	case 0:
    38  		return errERROR_EINVAL
    39  	case errnoERROR_IO_PENDING:
    40  		return errERROR_IO_PENDING
    41  	}
    42  	// TODO: add more here, after collecting data on the common
    43  	// error values see on Windows. (perhaps when running
    44  	// all.bat?)
    45  	return e
    46  }
    47  
    48  func recv(s windows.Handle, buf []byte, flags int32) (n int32, err error) {
    49  	var _p0 *byte
    50  	if len(buf) > 0 {
    51  		_p0 = &buf[0]
    52  	}
    53  	r0, _, e1 := syscall.SyscallN(procrecv.Addr(), uintptr(s), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), uintptr(flags))
    54  	n = int32(r0)
    55  	if n == -1 {
    56  		err = errnoErr(e1)
    57  	}
    58  	return
    59  }
    60  
    61  var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
    62  
    63  type syscallReadWaiter struct {
    64  	rawConn  syscall.RawConn
    65  	readErr  error
    66  	readFunc func(fd uintptr) (done bool)
    67  	hasData  bool
    68  	buffer   *buf.Buffer
    69  	options  N.ReadWaitOptions
    70  }
    71  
    72  func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
    73  	if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
    74  		rawConn, err := syscallConn.SyscallConn()
    75  		if err == nil {
    76  			return &syscallReadWaiter{rawConn: rawConn}, true
    77  		}
    78  	}
    79  	return nil, false
    80  }
    81  
    82  func (w *syscallReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
    83  	w.options = options
    84  	w.readFunc = func(fd uintptr) (done bool) {
    85  		if !w.hasData {
    86  			w.hasData = true
    87  			// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
    88  			// socket is readable if we return false. So the `recv` syscall will not block the system thread.
    89  			return false
    90  		}
    91  		buffer := w.options.NewBuffer()
    92  		var readN int32
    93  		readN, w.readErr = recv(windows.Handle(fd), buffer.FreeBytes(), 0)
    94  		if readN > 0 {
    95  			buffer.Truncate(int(readN))
    96  			w.options.PostReturn(buffer)
    97  			w.buffer = buffer
    98  		} else {
    99  			buffer.Release()
   100  		}
   101  		if w.readErr == windows.WSAEWOULDBLOCK {
   102  			return false
   103  		}
   104  		if readN == 0 && w.readErr == nil {
   105  			w.readErr = io.EOF
   106  		}
   107  		w.hasData = false
   108  		return true
   109  	}
   110  	return false
   111  }
   112  
   113  func (w *syscallReadWaiter) WaitReadBuffer() (buffer *buf.Buffer, err error) {
   114  	if w.readFunc == nil {
   115  		return nil, os.ErrInvalid
   116  	}
   117  	err = w.rawConn.Read(w.readFunc)
   118  	if err != nil {
   119  		return
   120  	}
   121  	if w.readErr != nil {
   122  		if w.readErr == io.EOF {
   123  			return nil, io.EOF
   124  		}
   125  		return nil, E.Cause(w.readErr, "raw read")
   126  	}
   127  	buffer = w.buffer
   128  	w.buffer = nil
   129  	return
   130  }
   131  
   132  var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
   133  
   134  type syscallPacketReadWaiter struct {
   135  	rawConn  syscall.RawConn
   136  	readErr  error
   137  	readFrom M.Socksaddr
   138  	readFunc func(fd uintptr) (done bool)
   139  	hasData  bool
   140  	buffer   *buf.Buffer
   141  	options  N.ReadWaitOptions
   142  }
   143  
   144  func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool) {
   145  	if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn {
   146  		rawConn, err := syscallConn.SyscallConn()
   147  		if err == nil {
   148  			return &syscallPacketReadWaiter{rawConn: rawConn}, true
   149  		}
   150  	}
   151  	return nil, false
   152  }
   153  
   154  func (w *syscallPacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
   155  	w.options = options
   156  	w.readFunc = func(fd uintptr) (done bool) {
   157  		if !w.hasData {
   158  			w.hasData = true
   159  			// golang's internal/poll.FD.RawRead will Use a zero-byte read as a way to get notified when this
   160  			// socket is readable if we return false. So the `recvfrom` syscall will not block the system thread.
   161  			return false
   162  		}
   163  		buffer := w.options.NewPacketBuffer()
   164  		var readN int
   165  		var from windows.Sockaddr
   166  		readN, from, w.readErr = windows.Recvfrom(windows.Handle(fd), buffer.FreeBytes(), 0)
   167  		if readN > 0 {
   168  			buffer.Truncate(readN)
   169  			w.options.PostReturn(buffer)
   170  			w.buffer = buffer
   171  		} else {
   172  			buffer.Release()
   173  		}
   174  		if w.readErr == windows.WSAEWOULDBLOCK {
   175  			return false
   176  		}
   177  		if from != nil {
   178  			switch fromAddr := from.(type) {
   179  			case *windows.SockaddrInet4:
   180  				w.readFrom = M.SocksaddrFrom(netip.AddrFrom4(fromAddr.Addr), uint16(fromAddr.Port))
   181  			case *windows.SockaddrInet6:
   182  				w.readFrom = M.SocksaddrFrom(netip.AddrFrom16(fromAddr.Addr), uint16(fromAddr.Port)).Unwrap()
   183  			}
   184  		}
   185  		w.hasData = false
   186  		return true
   187  	}
   188  	return false
   189  }
   190  
   191  func (w *syscallPacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
   192  	if w.readFunc == nil {
   193  		return nil, M.Socksaddr{}, os.ErrInvalid
   194  	}
   195  	err = w.rawConn.Read(w.readFunc)
   196  	if err != nil {
   197  		return
   198  	}
   199  	if w.readErr != nil {
   200  		err = E.Cause(w.readErr, "raw read")
   201  		return
   202  	}
   203  	buffer = w.buffer
   204  	w.buffer = nil
   205  	destination = w.readFrom
   206  	return
   207  }