github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/sysfs/poll_windows_test.go (about)

     1  package sysfs
     2  
     3  import (
     4  	"net"
     5  	"os"
     6  	"syscall"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/wasilibs/wazerox/experimental/sys"
    11  	"github.com/wasilibs/wazerox/internal/testing/require"
    12  )
    13  
    14  func TestPoll_Windows(t *testing.T) {
    15  	type result struct {
    16  		n   int
    17  		err sys.Errno
    18  	}
    19  
    20  	pollToChannel := func(fd uintptr, timeoutMillis int32, ch chan result) {
    21  		r := result{}
    22  		fds := []pollFd{{fd: fd, events: _POLLIN}}
    23  		r.n, r.err = _poll(fds, timeoutMillis)
    24  		ch <- r
    25  		close(ch)
    26  	}
    27  
    28  	t.Run("poll returns sys.ENOSYS when n == 0 and timeoutMillis is negative", func(t *testing.T) {
    29  		n, errno := _poll(nil, -1)
    30  		require.Equal(t, -1, n)
    31  		require.EqualErrno(t, sys.ENOSYS, errno)
    32  	})
    33  
    34  	t.Run("peekNamedPipe should report the correct state of incoming data in the pipe", func(t *testing.T) {
    35  		r, w, err := os.Pipe()
    36  		require.NoError(t, err)
    37  		rh := syscall.Handle(r.Fd())
    38  		wh := syscall.Handle(w.Fd())
    39  
    40  		// Ensure the pipe has no data.
    41  		n, err := peekNamedPipe(rh)
    42  		require.Zero(t, err)
    43  		require.Zero(t, n)
    44  
    45  		// Write to the channel.
    46  		msg, err := syscall.ByteSliceFromString("test\n")
    47  		require.NoError(t, err)
    48  		_, err = syscall.Write(wh, msg)
    49  		require.NoError(t, err)
    50  
    51  		// Ensure the pipe has data.
    52  		n, err = peekNamedPipe(rh)
    53  		require.Zero(t, err)
    54  		require.Equal(t, 6, int(n))
    55  	})
    56  
    57  	t.Run("peekPipes should return an error on invalid handle", func(t *testing.T) {
    58  		fds := []pollFd{{fd: uintptr(syscall.InvalidHandle)}}
    59  		_, err := peekPipes(fds)
    60  		require.EqualErrno(t, sys.EBADF, err)
    61  	})
    62  
    63  	t.Run("peekAll should return an error on invalid handle", func(t *testing.T) {
    64  		fds := []pollFd{{fd: uintptr(syscall.InvalidHandle)}}
    65  		_, _, err := peekAll(fds, nil)
    66  		require.EqualErrno(t, sys.EBADF, err)
    67  	})
    68  
    69  	t.Run("poll should return successfully with a regular file", func(t *testing.T) {
    70  		f, err := os.CreateTemp(t.TempDir(), "test")
    71  		require.NoError(t, err)
    72  		defer f.Close()
    73  
    74  		fds := []pollFd{{fd: f.Fd()}}
    75  
    76  		n, errno := _poll(fds, 0)
    77  		require.Zero(t, errno)
    78  		require.Equal(t, 1, n)
    79  	})
    80  
    81  	t.Run("peekAll should return successfully with a pipe", func(t *testing.T) {
    82  		r, w, err := os.Pipe()
    83  		require.NoError(t, err)
    84  		defer r.Close()
    85  		defer w.Close()
    86  
    87  		fds := []pollFd{{fd: r.Fd()}}
    88  
    89  		npipes, nsockets, errno := peekAll(fds, nil)
    90  		require.Zero(t, errno)
    91  		require.Equal(t, 0, npipes)
    92  		require.Equal(t, 0, nsockets)
    93  
    94  		w.Write([]byte("wazero"))
    95  		npipes, nsockets, errno = peekAll(fds, nil)
    96  		require.Zero(t, errno)
    97  		require.Equal(t, 1, npipes)
    98  		require.Equal(t, 0, nsockets)
    99  	})
   100  
   101  	t.Run("peekAll should return successfully with a socket", func(t *testing.T) {
   102  		listen, err := net.Listen("tcp", "127.0.0.1:0")
   103  		require.NoError(t, err)
   104  		defer listen.Close()
   105  
   106  		conn, err := listen.(*net.TCPListener).SyscallConn()
   107  		require.NoError(t, err)
   108  
   109  		fds := []pollFd{}
   110  		conn.Control(func(fd uintptr) {
   111  			fds = append(fds, pollFd{fd: fd, events: _POLLIN})
   112  		})
   113  
   114  		npipes, nsockets, errno := peekAll(nil, fds)
   115  		require.Zero(t, errno)
   116  		require.Equal(t, 0, npipes)
   117  		require.Equal(t, 0, nsockets)
   118  
   119  		tcpAddr, err := net.ResolveTCPAddr("tcp", listen.Addr().String())
   120  		require.NoError(t, err)
   121  		tcp, err := net.DialTCP("tcp", nil, tcpAddr)
   122  		require.NoError(t, err)
   123  		tcp.Write([]byte("wazero"))
   124  
   125  		conn.Control(func(fd uintptr) {
   126  			fds[0].fd = fd
   127  		})
   128  		npipes, nsockets, errno = peekAll(nil, fds)
   129  		require.Zero(t, errno)
   130  		require.Equal(t, 0, npipes)
   131  		require.Equal(t, 1, nsockets)
   132  	})
   133  
   134  	t.Run("poll should return immediately when duration is zero (no data)", func(t *testing.T) {
   135  		r, _, err := os.Pipe()
   136  		require.NoError(t, err)
   137  		fds := []pollFd{{fd: r.Fd(), events: _POLLIN}}
   138  		n, err := _poll(fds, 0)
   139  		require.Zero(t, err)
   140  		require.Zero(t, n)
   141  	})
   142  
   143  	t.Run("poll should return immediately when duration is zero (data)", func(t *testing.T) {
   144  		r, w, err := os.Pipe()
   145  		require.NoError(t, err)
   146  		fds := []pollFd{{fd: r.Fd(), events: _POLLIN}}
   147  		wh := syscall.Handle(w.Fd())
   148  
   149  		// Write to the channel immediately.
   150  		msg, err := syscall.ByteSliceFromString("test\n")
   151  		require.NoError(t, err)
   152  		_, err = syscall.Write(wh, msg)
   153  		require.NoError(t, err)
   154  
   155  		// Verify that the write is reported.
   156  		n, err := _poll(fds, 0)
   157  		require.Zero(t, err)
   158  		require.Equal(t, 1, n)
   159  	})
   160  
   161  	t.Run("poll should wait forever when duration is nil (no writes)", func(t *testing.T) {
   162  		r, _, err := os.Pipe()
   163  		require.NoError(t, err)
   164  
   165  		ch := make(chan result, 1)
   166  		go pollToChannel(r.Fd(), -1, ch)
   167  
   168  		// Wait a little, then ensure no writes occurred.
   169  		<-time.After(500 * time.Millisecond)
   170  		require.Equal(t, 0, len(ch))
   171  	})
   172  
   173  	t.Run("poll should wait forever when duration is nil", func(t *testing.T) {
   174  		r, w, err := os.Pipe()
   175  		require.NoError(t, err)
   176  		wh := syscall.Handle(w.Fd())
   177  
   178  		ch := make(chan result, 1)
   179  		go pollToChannel(r.Fd(), -1, ch)
   180  
   181  		// Wait a little, then ensure no writes occurred.
   182  		<-time.After(100 * time.Millisecond)
   183  		require.Equal(t, 0, len(ch))
   184  
   185  		// Write a message to the pipe.
   186  		msg, err := syscall.ByteSliceFromString("test\n")
   187  		require.NoError(t, err)
   188  		_, err = syscall.Write(wh, msg)
   189  		require.NoError(t, err)
   190  
   191  		// Ensure that the write occurs (panic after an arbitrary timeout).
   192  		select {
   193  		case <-time.After(500 * time.Millisecond):
   194  			t.Fatal("unreachable!")
   195  		case r := <-ch:
   196  			require.Zero(t, r.err)
   197  			require.NotEqual(t, 0, r.n)
   198  		}
   199  	})
   200  
   201  	t.Run("poll should wait for the given duration", func(t *testing.T) {
   202  		r, w, err := os.Pipe()
   203  		require.NoError(t, err)
   204  		wh := syscall.Handle(w.Fd())
   205  
   206  		ch := make(chan result, 1)
   207  		go pollToChannel(r.Fd(), 500, ch)
   208  
   209  		// Wait a little, then ensure no writes occurred.
   210  		<-time.After(100 * time.Millisecond)
   211  		require.Equal(t, 0, len(ch))
   212  
   213  		// Write a message to the pipe.
   214  		msg, err := syscall.ByteSliceFromString("test\n")
   215  		require.NoError(t, err)
   216  		_, err = syscall.Write(wh, msg)
   217  		require.NoError(t, err)
   218  
   219  		// Ensure that the write occurs before the timer expires.
   220  		select {
   221  		case <-time.After(500 * time.Millisecond):
   222  			panic("no data!")
   223  		case r := <-ch:
   224  			require.Zero(t, r.err)
   225  			require.Equal(t, 1, r.n)
   226  		}
   227  	})
   228  
   229  	t.Run("poll should timeout after the given duration", func(t *testing.T) {
   230  		r, _, err := os.Pipe()
   231  		require.NoError(t, err)
   232  
   233  		ch := make(chan result, 1)
   234  		go pollToChannel(r.Fd(), 200, ch)
   235  
   236  		// Ensure that the timer has expired.
   237  		res := <-ch
   238  		require.Zero(t, res.err)
   239  		require.Zero(t, res.n)
   240  	})
   241  
   242  	t.Run("poll should return when a write occurs before the given duration", func(t *testing.T) {
   243  		r, w, err := os.Pipe()
   244  		require.NoError(t, err)
   245  		wh := syscall.Handle(w.Fd())
   246  
   247  		ch := make(chan result, 1)
   248  		go pollToChannel(r.Fd(), 800, ch)
   249  
   250  		<-time.After(300 * time.Millisecond)
   251  		require.Equal(t, 0, len(ch))
   252  
   253  		msg, err := syscall.ByteSliceFromString("test\n")
   254  		require.NoError(t, err)
   255  		_, err = syscall.Write(wh, msg)
   256  		require.NoError(t, err)
   257  
   258  		res := <-ch
   259  		require.Zero(t, res.err)
   260  		require.Equal(t, 1, res.n)
   261  	})
   262  
   263  	t.Run("poll should return when a regular file is given", func(t *testing.T) {
   264  		f, err := os.CreateTemp(t.TempDir(), "ex")
   265  		defer f.Close()
   266  		require.NoError(t, err)
   267  		fds := []pollFd{{fd: f.Fd(), events: _POLLIN}}
   268  		n, errno := _poll(fds, 0)
   269  		require.Zero(t, errno)
   270  		require.Equal(t, 1, n)
   271  	})
   272  }