github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/imports/wasi_snapshot_preview1/sock_test.go (about)

     1  package wasi_snapshot_preview1_test
     2  
     3  import (
     4  	"bytes"
     5  	"net"
     6  	"strings"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/bananabytelabs/wazero"
    11  	"github.com/bananabytelabs/wazero/api"
    12  	experimentalsock "github.com/bananabytelabs/wazero/experimental/sock"
    13  	"github.com/bananabytelabs/wazero/internal/sys"
    14  	"github.com/bananabytelabs/wazero/internal/testing/require"
    15  	"github.com/bananabytelabs/wazero/internal/wasip1"
    16  	"github.com/bananabytelabs/wazero/internal/wasm"
    17  )
    18  
    19  func Test_sockAccept(t *testing.T) {
    20  	tests := []struct {
    21  		name          string
    22  		flags         uint16
    23  		expectedErrno wasip1.Errno
    24  		expectedLog   string
    25  		body          func(mod api.Module, log *bytes.Buffer, fd, connFd uintptr, tcp *net.TCPConn)
    26  	}{
    27  		{
    28  			name:          "sock_accept",
    29  			flags:         0,
    30  			expectedErrno: wasip1.ErrnoSuccess,
    31  			expectedLog: `
    32  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=)
    33  <== (fd=4,errno=ESUCCESS)
    34  `,
    35  		},
    36  		{
    37  			name:  "sock_accept (nonblock)",
    38  			flags: wasip1.FD_NONBLOCK,
    39  			expectedLog: `
    40  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=NONBLOCK)
    41  <== (fd=4,errno=ESUCCESS)
    42  `,
    43  		},
    44  	}
    45  
    46  	for _, tc := range tests {
    47  		t.Run(tc.name, func(t *testing.T) {
    48  			ctx := experimentalsock.WithConfig(testCtx, experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0))
    49  
    50  			mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig())
    51  			defer r.Close(testCtx)
    52  
    53  			// Dial the socket so that a call to accept doesn't hang.
    54  			tcpAddr := requireTCPListenerAddr(t, mod)
    55  			tcp, err := net.DialTCP("tcp", nil, tcpAddr)
    56  			require.NoError(t, err)
    57  			defer tcp.Close() //nolint
    58  
    59  			requireErrnoResult(t, tc.expectedErrno, mod, wasip1.SockAcceptName, uint64(sys.FdPreopen), uint64(tc.flags), 128)
    60  			connFd, _ := mod.Memory().ReadUint32Le(128)
    61  			require.Equal(t, uint32(4), connFd)
    62  
    63  			require.Equal(t, tc.expectedLog, "\n"+log.String())
    64  		})
    65  	}
    66  }
    67  
    68  func Test_sockShutdown(t *testing.T) {
    69  	tests := []struct {
    70  		name          string
    71  		flags         uint8
    72  		expectedErrno wasip1.Errno
    73  		expectedLog   string
    74  	}{
    75  		{
    76  			name:          "sock_shutdown",
    77  			flags:         wasip1.SD_WR | wasip1.SD_RD,
    78  			expectedErrno: wasip1.ErrnoSuccess,
    79  			expectedLog: `
    80  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=)
    81  <== (fd=4,errno=ESUCCESS)
    82  ==> wasi_snapshot_preview1.sock_shutdown(fd=4,how=RD|WR)
    83  <== errno=ESUCCESS
    84  `,
    85  		},
    86  		{
    87  			name:          "sock_shutdown: fail with no flags",
    88  			flags:         0,
    89  			expectedErrno: wasip1.ErrnoInval,
    90  			expectedLog: `
    91  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=)
    92  <== (fd=4,errno=ESUCCESS)
    93  ==> wasi_snapshot_preview1.sock_shutdown(fd=4,how=)
    94  <== errno=EINVAL
    95  `,
    96  		},
    97  	}
    98  
    99  	for _, tc := range tests {
   100  		t.Run(tc.name, func(t *testing.T) {
   101  			ctx := experimentalsock.WithConfig(testCtx, experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0))
   102  
   103  			mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig())
   104  			defer r.Close(testCtx)
   105  
   106  			// Dial the socket so that a call to accept doesn't hang.
   107  			tcpAddr := requireTCPListenerAddr(t, mod)
   108  			tcp, err := net.DialTCP("tcp", nil, tcpAddr)
   109  			require.NoError(t, err)
   110  			defer tcp.Close() //nolint
   111  
   112  			requireErrnoResult(t, wasip1.ErrnoSuccess, mod, wasip1.SockAcceptName, uint64(sys.FdPreopen), uint64(0), 128)
   113  			connFd, _ := mod.Memory().ReadUint32Le(128)
   114  			require.Equal(t, uint32(4), connFd)
   115  
   116  			// End of setup. Perform the test.
   117  			requireErrnoResult(t, tc.expectedErrno, mod, wasip1.SockShutdownName, uint64(connFd), uint64(tc.flags))
   118  
   119  			require.Equal(t, tc.expectedLog, "\n"+log.String())
   120  		})
   121  	}
   122  }
   123  
   124  func Test_sockRecv(t *testing.T) {
   125  	tests := []struct {
   126  		name           string
   127  		funcName       string
   128  		flags          uint8
   129  		expectedErrno  wasip1.Errno
   130  		expectedLog    string
   131  		initialMemory  []byte
   132  		iovsCount      uint64
   133  		expectedMemory []byte
   134  	}{
   135  		{
   136  			name:      "sock_recv",
   137  			iovsCount: 3,
   138  			initialMemory: []byte{
   139  				'?',         // `iovs` is after this
   140  				26, 0, 0, 0, // = iovs[0].offset
   141  				4, 0, 0, 0, // = iovs[0].length
   142  				31, 0, 0, 0, // = iovs[1].offset
   143  				0, 0, 0, 0, // = iovs[1].length == 0 !!
   144  				31, 0, 0, 0, // = iovs[2].offset
   145  				2, 0, 0, 0, // = iovs[2].length
   146  				'?',
   147  			},
   148  			expectedMemory: []byte{
   149  				'w', 'a', 'z', 'e', // iovs[0].length bytes
   150  				'?',      // iovs[2].offset is after this
   151  				'r', 'o', // iovs[2].length bytes
   152  				'?',        // resultNread is after this
   153  				6, 0, 0, 0, // sum(iovs[...].length) == length of "wazero"
   154  				0, 0, // flags
   155  				'?',
   156  			},
   157  			expectedLog: `
   158  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=)
   159  <== (fd=4,errno=ESUCCESS)
   160  ==> wasi_snapshot_preview1.sock_recv(fd=4,ri_data=1,ri_data_len=3,ri_flags=)
   161  <== (ro_datalen=6,ro_flags=,errno=ESUCCESS)
   162  `,
   163  		},
   164  
   165  		{
   166  			name:      "sock_recv (WAITALL)",
   167  			flags:     wasip1.RI_RECV_WAITALL,
   168  			iovsCount: 3,
   169  			initialMemory: []byte{
   170  				'?',         // `iovs` is after this
   171  				26, 0, 0, 0, // = iovs[0].offset
   172  				4, 0, 0, 0, // = iovs[0].length
   173  				31, 0, 0, 0, // = iovs[1].offset
   174  				0, 0, 0, 0, // = iovs[1].length == 0 !!
   175  				31, 0, 0, 0, // = iovs[2].offset
   176  				2, 0, 0, 0, // = iovs[2].length
   177  				'?',
   178  			},
   179  			expectedMemory: []byte{
   180  				'w', 'a', 'z', 'e', // iovs[0].length bytes
   181  				'?',      // iovs[2].offset is after this
   182  				'r', 'o', // iovs[2].length bytes
   183  				'?',        // resultNread is after this
   184  				6, 0, 0, 0, // sum(iovs[...].length) == length of "wazero"
   185  				0, 0, // flags
   186  				'?',
   187  			},
   188  
   189  			expectedLog: `
   190  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=)
   191  <== (fd=4,errno=ESUCCESS)
   192  ==> wasi_snapshot_preview1.sock_recv(fd=4,ri_data=1,ri_data_len=3,ri_flags=RECV_WAITALL)
   193  <== (ro_datalen=6,ro_flags=,errno=ESUCCESS)
   194  `,
   195  		},
   196  
   197  		{
   198  			name:      "sock_recv (PEEK)",
   199  			flags:     wasip1.RI_RECV_PEEK,
   200  			iovsCount: 3,
   201  			initialMemory: []byte{
   202  				'?',         // `iovs` is after this
   203  				26, 0, 0, 0, // = iovs[0].offset
   204  				4, 0, 0, 0, // = iovs[0].length
   205  				31, 0, 0, 0, // = iovs[1].offset
   206  				0, 0, 0, 0, // = iovs[1].length == 0 !!
   207  				31, 0, 0, 0, // = iovs[2].offset
   208  				2, 0, 0, 0, // = iovs[2].length
   209  				'?',
   210  			},
   211  			expectedMemory: []byte{
   212  				'w', 'a', 'z', 'e', // iovs[0].length bytes
   213  				'?', '?', '?', '?', // pad to 34
   214  				4, 0, 0, 0, // result.ro_datalen
   215  				0, 0, // result.ro_flags
   216  				'?',
   217  			},
   218  			expectedLog: `
   219  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=)
   220  <== (fd=4,errno=ESUCCESS)
   221  ==> wasi_snapshot_preview1.sock_recv(fd=4,ri_data=1,ri_data_len=3,ri_flags=RECV_PEEK)
   222  <== (ro_datalen=4,ro_flags=,errno=ESUCCESS)
   223  `,
   224  		},
   225  		{
   226  			name:  "sock_recv: fail with unknown flags",
   227  			flags: 42,
   228  			expectedLog: `
   229  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=)
   230  <== (fd=4,errno=ESUCCESS)
   231  ==> wasi_snapshot_preview1.sock_recv(fd=4,ri_data=1,ri_data_len=0,ri_flags=RECV_WAITALL)
   232  <== (ro_datalen=,ro_flags=,errno=ENOTSUP)
   233  `,
   234  			expectedErrno: wasip1.ErrnoNotsup,
   235  		},
   236  	}
   237  
   238  	for _, tc := range tests {
   239  		t.Run(tc.name, func(t *testing.T) {
   240  			ctx := experimentalsock.WithConfig(testCtx, experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0))
   241  
   242  			mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig())
   243  			defer r.Close(testCtx)
   244  
   245  			// Dial the socket so that a call to accept doesn't hang.
   246  			tcpAddr := requireTCPListenerAddr(t, mod)
   247  			tcp, err := net.DialTCP("tcp", nil, tcpAddr)
   248  			require.NoError(t, err)
   249  			defer tcp.Close() //nolint
   250  
   251  			requireErrnoResult(t, wasip1.ErrnoSuccess, mod, wasip1.SockAcceptName, uint64(sys.FdPreopen), uint64(0), 128)
   252  			connFd, _ := mod.Memory().ReadUint32Le(128)
   253  			require.Equal(t, uint32(4), connFd)
   254  
   255  			// End of setup. Perform the test.
   256  
   257  			write, err := tcp.Write([]byte("wazero"))
   258  			require.NoError(t, err)
   259  			require.NotEqual(t, 0, write)
   260  
   261  			iovs := uint32(1)             // arbitrary offset
   262  			resultRoDatalen := uint32(34) // arbitrary offset
   263  			expectedMemory := append(tc.initialMemory, tc.expectedMemory...)
   264  			maskMemory(t, mod, len(expectedMemory))
   265  
   266  			ok := mod.Memory().Write(0, tc.initialMemory)
   267  			require.True(t, ok)
   268  
   269  			// Special case this test: let us add a bit of delay
   270  			// to avoid EAGAIN.
   271  			if tc.flags == wasip1.RI_RECV_PEEK {
   272  				time.Sleep(500 * time.Millisecond)
   273  			}
   274  
   275  			requireErrnoResult(t, tc.expectedErrno, mod, wasip1.SockRecvName, uint64(connFd), uint64(iovs), tc.iovsCount, uint64(tc.flags), uint64(resultRoDatalen), uint64(resultRoDatalen+4))
   276  			require.Equal(t, tc.expectedLog, "\n"+log.String())
   277  
   278  			actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory)))
   279  			require.True(t, ok)
   280  			require.Equal(t, expectedMemory, actual)
   281  		})
   282  	}
   283  }
   284  
   285  func Test_sockSend(t *testing.T) {
   286  	tests := []struct {
   287  		name           string
   288  		funcName       string
   289  		flags          uint32
   290  		expectedErrno  wasip1.Errno
   291  		expectedLog    string
   292  		initialMemory  []byte
   293  		iovsCount      uint64
   294  		expectedMemory []byte
   295  	}{
   296  		{
   297  			name:      "sock_send",
   298  			iovsCount: 3,
   299  			initialMemory: []byte{
   300  				'?',         // `iovs` is after this
   301  				18, 0, 0, 0, // = iovs[0].offset
   302  				4, 0, 0, 0, // = iovs[0].length
   303  				23, 0, 0, 0, // = iovs[1].offset
   304  				2, 0, 0, 0, // = iovs[1].length
   305  				'?',                // iovs[0].offset is after this
   306  				'w', 'a', 'z', 'e', // iovs[0].length bytes
   307  				'?',      // iovs[1].offset is after this
   308  				'r', 'o', // iovs[1].length bytes
   309  				'?',
   310  			},
   311  			expectedMemory: []byte{
   312  				6, 0, 0, 0, // sum(iovs[...].length) == length of "wazero"
   313  				'?',
   314  			},
   315  
   316  			expectedLog: `
   317  ==> wasi_snapshot_preview1.sock_accept(fd=3,flags=)
   318  <== (fd=4,errno=ESUCCESS)
   319  ==> wasi_snapshot_preview1.sock_send(fd=4,si_data=1,si_data_len=2,si_flags=)
   320  <== (so_datalen=6,errno=ESUCCESS)
   321  `,
   322  		},
   323  	}
   324  
   325  	for _, tc := range tests {
   326  		t.Run(tc.name, func(t *testing.T) {
   327  			ctx := experimentalsock.WithConfig(testCtx, experimentalsock.NewConfig().WithTCPListener("127.0.0.1", 0))
   328  
   329  			mod, r, log := requireProxyModuleWithContext(ctx, t, wazero.NewModuleConfig())
   330  			defer r.Close(testCtx)
   331  
   332  			// Dial the socket so that a call to accept doesn't hang.
   333  			tcpAddr := requireTCPListenerAddr(t, mod)
   334  			tcp, err := net.DialTCP("tcp", nil, tcpAddr)
   335  			require.NoError(t, err)
   336  			defer tcp.Close() //nolint
   337  
   338  			requireErrnoResult(t, wasip1.ErrnoSuccess, mod, wasip1.SockAcceptName, uint64(sys.FdPreopen), uint64(0), 128)
   339  			connFd, _ := mod.Memory().ReadUint32Le(128)
   340  			require.Equal(t, uint32(4), connFd)
   341  
   342  			// End of setup. Perform the test.
   343  			iovs := uint32(1)             // arbitrary offset
   344  			iovsCount := uint32(2)        // The count of iovs
   345  			resultSoDatalen := uint32(26) // arbitrary offset
   346  			expectedMemory := append(tc.initialMemory, tc.expectedMemory...)
   347  
   348  			maskMemory(t, mod, len(expectedMemory))
   349  			ok := mod.Memory().Write(0, tc.initialMemory)
   350  			require.True(t, ok)
   351  
   352  			requireErrnoResult(t, wasip1.ErrnoSuccess, mod, wasip1.SockSendName, uint64(connFd), uint64(iovs), uint64(iovsCount), 0, uint64(resultSoDatalen))
   353  			require.Equal(t, tc.expectedLog, "\n"+log.String())
   354  
   355  			actual, ok := mod.Memory().Read(0, uint32(len(expectedMemory)))
   356  			require.True(t, ok)
   357  			require.Equal(t, expectedMemory, actual)
   358  
   359  			// Read back the value that was sent on the socket.
   360  			buf := make([]byte, 10)
   361  			read, err := tcp.Read(buf)
   362  			require.NoError(t, err)
   363  			require.NotEqual(t, 0, read)
   364  			// Sometimes `buf` is smaller than len("wazero").
   365  			require.True(t, strings.HasPrefix("wazero", string(buf[:read])))
   366  		})
   367  	}
   368  }
   369  
// addr matches any file that can report a *net.TCPAddr.
// requireTCPListenerAddr asserts the preopened socket file to this interface to
// read the listener's bound address.
type addr interface {
	// Addr returns the TCP address of the underlying socket.
	Addr() *net.TCPAddr
}
   373  
   374  func requireTCPListenerAddr(t *testing.T, mod api.Module) *net.TCPAddr {
   375  	sock, ok := mod.(*wasm.ModuleInstance).Sys.FS().LookupFile(sys.FdPreopen)
   376  	require.True(t, ok)
   377  	return sock.File.(addr).Addr()
   378  }