gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/link/sharedmem/sharedmem_server_test.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 //go:build linux 16 // +build linux 17 18 package sharedmem_server_test 19 20 import ( 21 "fmt" 22 "io" 23 "net" 24 "net/http" 25 "os" 26 "strings" 27 "syscall" 28 "testing" 29 30 "golang.org/x/sync/errgroup" 31 "golang.org/x/sys/unix" 32 "gvisor.dev/gvisor/pkg/log" 33 "gvisor.dev/gvisor/pkg/refs" 34 "gvisor.dev/gvisor/pkg/tcpip" 35 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 36 "gvisor.dev/gvisor/pkg/tcpip/header" 37 "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" 38 "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem" 39 "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" 40 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 41 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 42 "gvisor.dev/gvisor/pkg/tcpip/stack" 43 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 44 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 45 ) 46 47 const ( 48 localLinkAddr = "\xde\xad\xbe\xef\x56\x78" 49 remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34" 50 serverPort = 10001 51 52 defaultMTU = 65536 53 defaultBufferSize = 1500 54 55 // qDisc options 56 numQueues = 1 57 queueLen = 1000 58 ) 59 60 var ( 61 localIPv4Address = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x01")) 62 remoteIPv4Address = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x02")) 63 ) 64 65 type stackOptions struct { 66 ep stack.LinkEndpoint 67 addr tcpip.Address 68 enablePacketLogs bool 69 } 70 71 func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) { 72 st := stack.New(stack.Options{ 73 NetworkProtocols: []stack.NetworkProtocolFactory{ 74 ipv4.NewProtocolWithOptions(ipv4.Options{ 75 AllowExternalLoopbackTraffic: true, 76 }), 77 ipv6.NewProtocolWithOptions(ipv6.Options{ 78 AllowExternalLoopbackTraffic: true, 79 }), 80 }, 81 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, 82 }) 83 nicID := tcpip.NICID(1) 84 ep := stackOpts.ep 85 if stackOpts.enablePacketLogs { 86 ep = sniffer.New(stackOpts.ep) 87 } 88 qDisc := fifo.New(ep, int(numQueues), int(queueLen)) 89 opts := stack.NICOptions{ 90 Name: "eth0", 91 QDisc: qDisc, 92 } 93 if err := st.CreateNICWithOptions(nicID, ep, opts); err != nil { 94 return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err) 95 } 96 97 // Add Protocol Address. 98 protocolNum := ipv4.ProtocolNumber 99 routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}} 100 if stackOpts.addr.Len() == 16 { 101 routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}} 102 protocolNum = ipv6.ProtocolNumber 103 } 104 protocolAddr := tcpip.ProtocolAddress{ 105 Protocol: protocolNum, 106 AddressWithPrefix: stackOpts.addr.WithPrefix(), 107 } 108 if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { 109 return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err) 110 } 111 112 // Setup route table. 113 st.SetRouteTable(routeTable) 114 115 return st, nil 116 } 117 118 func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { 119 ep, err := sharedmem.New(sharedmem.Options{ 120 MTU: defaultMTU, 121 BufferSize: defaultBufferSize, 122 LinkAddress: localLinkAddr, 123 TX: qPair.TXQueueConfig(), 124 RX: qPair.RXQueueConfig(), 125 PeerFD: peerFD, 126 }) 127 if err != nil { 128 return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err) 129 } 130 st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address, enablePacketLogs: false}) 131 if err != nil { 132 return nil, fmt.Errorf("failed to create client stack: %s", err) 133 } 134 return st, nil 135 } 136 137 func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) { 138 ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{ 139 MTU: defaultMTU, 140 BufferSize: defaultBufferSize, 141 LinkAddress: remoteLinkAddr, 142 TX: qPair.TXQueueConfig(), 143 RX: qPair.RXQueueConfig(), 144 PeerFD: peerFD, 145 }) 146 if err != nil { 147 return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err) 148 } 149 st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address, enablePacketLogs: false}) 150 if err != nil { 151 return nil, fmt.Errorf("failed to create client stack: %s", err) 152 } 153 return st, nil 154 } 155 156 type testContext struct { 157 clientStk *stack.Stack 158 serverStk *stack.Stack 159 peerFDs [2]int 160 } 161 162 func newTestContext(t *testing.T) *testContext { 163 peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0) 164 if err != nil { 165 t.Fatalf("failed to create peerFDs: %s", err) 166 } 167 q, err := sharedmem.NewQueuePair(sharedmem.QueueOptions{}) 168 if err != nil { 169 t.Fatalf("failed to create sharedmem queue: %s", err) 170 } 171 clientStack, err := newClientStack(t, q, peerFDs[0]) 172 if err != nil { 173 q.Close() 174 unix.Close(peerFDs[0]) 175 unix.Close(peerFDs[1]) 176 t.Fatalf("failed to create client stack: %s", err) 177 } 178 serverStack, err := newServerStack(t, q, peerFDs[1]) 179 if err != nil { 180 q.Close() 181 unix.Close(peerFDs[0]) 182 unix.Close(peerFDs[1]) 183 clientStack.Close() 184 t.Fatalf("failed to create server stack: %s", err) 185 } 186 return &testContext{ 187 clientStk: clientStack, 188 serverStk: serverStack, 189 peerFDs: peerFDs, 190 } 191 } 192 193 func (ctx *testContext) cleanup() { 194 ctx.clientStk.RemoveNIC(tcpip.NICID(1)) 195 ctx.serverStk.RemoveNIC(tcpip.NICID(1)) 196 unix.Close(ctx.peerFDs[0]) 197 unix.Close(ctx.peerFDs[1]) 198 ctx.clientStk.Close() 199 ctx.serverStk.Close() 200 ctx.clientStk.Wait() 201 ctx.serverStk.Wait() 202 } 203 204 func makeRequest(serverAddr tcpip.FullAddress, clientStk *stack.Stack) (*http.Response, error) { 205 dialFunc := func(address, protocol string) (net.Conn, error) { 206 return gonet.DialTCP(clientStk, serverAddr, ipv4.ProtocolNumber) 207 } 208 httpClient := &http.Client{ 209 Transport: &http.Transport{ 210 Dial: dialFunc, 211 }, 212 } 213 // Close idle "keep alive" connections. If any connections remain open after 214 // a test ends, DoLeakCheck() will erroneously detect leaked packets. 215 defer httpClient.CloseIdleConnections() 216 serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(serverAddr.Addr.AsSlice()), serverAddr.Port) 217 response, err := httpClient.Get(serverURL) 218 return response, err 219 } 220 221 func TestServerRoundTrip(t *testing.T) { 222 ctx := newTestContext(t) 223 defer ctx.cleanup() 224 listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort} 225 l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber) 226 if err != nil { 227 t.Fatalf("failed to start TCP Listener: %s", err) 228 } 229 defer l.Close() 230 var responseString = strings.Repeat("response", 8<<10) 231 go func() { 232 http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 233 w.Write([]byte(responseString)) 234 })) 235 }() 236 237 response, err := makeRequest(listenAddr, ctx.clientStk) 238 if err != nil { 239 t.Fatalf("httpClient.Get(\"/\") failed: %s", err) 240 } 241 if got, want := response.StatusCode, http.StatusOK; got != want { 242 t.Fatalf("unexpected status code got: %d, want: %d", got, want) 243 } 244 body, err := io.ReadAll(response.Body) 245 if err != nil { 246 t.Fatalf("io.ReadAll(response.Body) failed: %s", err) 247 } 248 response.Body.Close() 249 if got, want := string(body), responseString; got != want { 250 t.Fatalf("unexpected response got: %s, want: %s", got, want) 251 } 252 } 253 254 func TestServerRoundTripStress(t *testing.T) { 255 ctx := newTestContext(t) 256 defer ctx.cleanup() 257 listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort} 258 l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber) 259 if err != nil { 260 t.Fatalf("failed to start TCP Listener: %s", err) 261 } 262 defer l.Close() 263 var responseString = strings.Repeat("response", 8<<10) 264 go func() { 265 http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 266 w.Write([]byte(responseString)) 267 })) 268 }() 269 270 var errs errgroup.Group 271 for i := 0; i < 1000; i++ { 272 errs.Go(func() error { 273 response, err := makeRequest(listenAddr, ctx.clientStk) 274 if err != nil { 275 return fmt.Errorf("httpClient.Get(\"/\") failed: %s", err) 276 } 277 if got, want := response.StatusCode, http.StatusOK; got != want { 278 return fmt.Errorf("unexpected status code got: %d, want: %d", got, want) 279 } 280 body, err := io.ReadAll(response.Body) 281 if err != nil { 282 return fmt.Errorf("io.ReadAll(response.Body) failed: %s", err) 283 } 284 response.Body.Close() 285 if got, want := string(body), responseString; got != want { 286 return fmt.Errorf("unexpected response got: %s, want: %s", got, want) 287 } 288 log.Infof("worker: read %d bytes", len(body)) 289 return nil 290 }) 291 } 292 if err := errs.Wait(); err != nil { 293 t.Fatalf("request failed: %s", err) 294 } 295 } 296 297 func TestServerBulkTransfer(t *testing.T) { 298 var payloadSizes = []int{ 299 512 << 20, // 512 MiB 300 1024 << 20, // 1 GiB 301 2048 << 20, // 2 GiB 302 } 303 304 for _, payloadSize := range payloadSizes { 305 t.Run(fmt.Sprintf("%d bytes", payloadSize), func(t *testing.T) { 306 ctx := newTestContext(t) 307 defer ctx.cleanup() 308 listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort} 309 l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber) 310 if err != nil { 311 t.Fatalf("failed to start TCP Listener: %s", err) 312 } 313 defer l.Close() 314 315 const chunkSize = 4 << 20 // 4 MiB 316 var responseString = strings.Repeat("r", chunkSize) 317 go func() { 318 http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 319 for done := 0; done < payloadSize; { 320 n, err := w.Write([]byte(responseString)) 321 if err != nil { 322 log.Infof("failed to write response : %s", err) 323 return 324 } 325 done += n 326 } 327 })) 328 }() 329 330 response, err := makeRequest(listenAddr, ctx.clientStk) 331 if err != nil { 332 t.Fatalf("httpClient.Get(\"/\") failed: %s", err) 333 } 334 if got, want := response.StatusCode, http.StatusOK; got != want { 335 t.Fatalf("unexpected status code got: %d, want: %d", got, want) 336 } 337 n, err := io.Copy(io.Discard, response.Body) 338 if err != nil { 339 t.Fatalf("io.Copy(io.Discard, response.Body) failed: %s", err) 340 } 341 response.Body.Close() 342 if got, want := int(n), payloadSize; got != want { 343 t.Fatalf("unexpected response size got: %d, want: %d", got, want) 344 } 345 log.Infof("read %d bytes", n) 346 }) 347 } 348 349 } 350 351 func TestClientBulkTransfer(t *testing.T) { 352 var payloadSizes = []int{ 353 512 << 20, // 512 MiB 354 1024 << 20, // 1 GiB 355 2048 << 20, // 2 GiB 356 } 357 358 for _, payloadSize := range payloadSizes { 359 t.Run(fmt.Sprintf("%d bytes", payloadSize), func(t *testing.T) { 360 ctx := newTestContext(t) 361 defer ctx.cleanup() 362 listenAddr := tcpip.FullAddress{Addr: localIPv4Address, Port: serverPort} 363 l, err := gonet.ListenTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber) 364 if err != nil { 365 t.Fatalf("failed to start TCP Listener: %s", err) 366 } 367 defer l.Close() 368 const chunkSize = 4 << 20 // 4 MiB 369 var responseString = strings.Repeat("r", chunkSize) 370 go func() { 371 http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 372 for done := 0; done < payloadSize; { 373 n, err := w.Write([]byte(responseString)) 374 if err != nil { 375 log.Infof("failed to write response : %s", err) 376 return 377 } 378 done += n 379 } 380 })) 381 }() 382 383 response, err := makeRequest(listenAddr, ctx.serverStk) 384 if err != nil { 385 t.Fatalf("httpClient.Get(\"/\") failed: %s", err) 386 } 387 if err != nil { 388 t.Fatalf("httpClient.Get(\"/\") failed: %s", err) 389 } 390 if got, want := response.StatusCode, http.StatusOK; got != want { 391 t.Fatalf("unexpected status code got: %d, want: %d", got, want) 392 } 393 n, err := io.Copy(io.Discard, response.Body) 394 if err != nil { 395 t.Fatalf("io.Copy(io.Discard, response.Body) failed: %s", err) 396 } 397 response.Body.Close() 398 if got, want := int(n), payloadSize; got != want { 399 t.Fatalf("unexpected response size got: %d, want: %d", got, want) 400 } 401 log.Infof("read %d bytes", n) 402 }) 403 } 404 } 405 406 func TestMain(m *testing.M) { 407 refs.SetLeakMode(refs.LeaksPanic) 408 code := m.Run() 409 refs.DoLeakCheck() 410 os.Exit(code) 411 }