gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/link/sharedmem/sharedmem_server_test.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build linux
    16  // +build linux
    17  
    18  package sharedmem_server_test
    19  
import (
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"strings"
	"syscall"
	"testing"

	"golang.org/x/sync/errgroup"
	"golang.org/x/sys/unix"
	"gvisor.dev/gvisor/pkg/log"
	"gvisor.dev/gvisor/pkg/refs"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo"
	"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem"
	"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
    46  
    47  const (
    48  	localLinkAddr  = "\xde\xad\xbe\xef\x56\x78"
    49  	remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
    50  	serverPort     = 10001
    51  
    52  	defaultMTU        = 65536
    53  	defaultBufferSize = 1500
    54  
    55  	// qDisc options
    56  	numQueues = 1
    57  	queueLen  = 1000
    58  )
    59  
    60  var (
    61  	localIPv4Address  = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x01"))
    62  	remoteIPv4Address = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x02"))
    63  )
    64  
    65  type stackOptions struct {
    66  	ep               stack.LinkEndpoint
    67  	addr             tcpip.Address
    68  	enablePacketLogs bool
    69  }
    70  
    71  func newStackWithOptions(stackOpts stackOptions) (*stack.Stack, error) {
    72  	st := stack.New(stack.Options{
    73  		NetworkProtocols: []stack.NetworkProtocolFactory{
    74  			ipv4.NewProtocolWithOptions(ipv4.Options{
    75  				AllowExternalLoopbackTraffic: true,
    76  			}),
    77  			ipv6.NewProtocolWithOptions(ipv6.Options{
    78  				AllowExternalLoopbackTraffic: true,
    79  			}),
    80  		},
    81  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
    82  	})
    83  	nicID := tcpip.NICID(1)
    84  	ep := stackOpts.ep
    85  	if stackOpts.enablePacketLogs {
    86  		ep = sniffer.New(stackOpts.ep)
    87  	}
    88  	qDisc := fifo.New(ep, int(numQueues), int(queueLen))
    89  	opts := stack.NICOptions{
    90  		Name:  "eth0",
    91  		QDisc: qDisc,
    92  	}
    93  	if err := st.CreateNICWithOptions(nicID, ep, opts); err != nil {
    94  		return nil, fmt.Errorf("method CreateNICWithOptions(%d, _, %v) failed: %s", nicID, opts, err)
    95  	}
    96  
    97  	// Add Protocol Address.
    98  	protocolNum := ipv4.ProtocolNumber
    99  	routeTable := []tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}
   100  	if stackOpts.addr.Len() == 16 {
   101  		routeTable = []tcpip.Route{{Destination: header.IPv6EmptySubnet, NIC: nicID}}
   102  		protocolNum = ipv6.ProtocolNumber
   103  	}
   104  	protocolAddr := tcpip.ProtocolAddress{
   105  		Protocol:          protocolNum,
   106  		AddressWithPrefix: stackOpts.addr.WithPrefix(),
   107  	}
   108  	if err := st.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
   109  		return nil, fmt.Errorf("AddProtocolAddress(%d, %v, {}): %s", nicID, protocolAddr, err)
   110  	}
   111  
   112  	// Setup route table.
   113  	st.SetRouteTable(routeTable)
   114  
   115  	return st, nil
   116  }
   117  
   118  func newClientStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
   119  	ep, err := sharedmem.New(sharedmem.Options{
   120  		MTU:         defaultMTU,
   121  		BufferSize:  defaultBufferSize,
   122  		LinkAddress: localLinkAddr,
   123  		TX:          qPair.TXQueueConfig(),
   124  		RX:          qPair.RXQueueConfig(),
   125  		PeerFD:      peerFD,
   126  	})
   127  	if err != nil {
   128  		return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
   129  	}
   130  	st, err := newStackWithOptions(stackOptions{ep: ep, addr: localIPv4Address, enablePacketLogs: false})
   131  	if err != nil {
   132  		return nil, fmt.Errorf("failed to create client stack: %s", err)
   133  	}
   134  	return st, nil
   135  }
   136  
   137  func newServerStack(t *testing.T, qPair *sharedmem.QueuePair, peerFD int) (*stack.Stack, error) {
   138  	ep, err := sharedmem.NewServerEndpoint(sharedmem.Options{
   139  		MTU:         defaultMTU,
   140  		BufferSize:  defaultBufferSize,
   141  		LinkAddress: remoteLinkAddr,
   142  		TX:          qPair.TXQueueConfig(),
   143  		RX:          qPair.RXQueueConfig(),
   144  		PeerFD:      peerFD,
   145  	})
   146  	if err != nil {
   147  		return nil, fmt.Errorf("failed to create sharedmem endpoint: %s", err)
   148  	}
   149  	st, err := newStackWithOptions(stackOptions{ep: ep, addr: remoteIPv4Address, enablePacketLogs: false})
   150  	if err != nil {
   151  		return nil, fmt.Errorf("failed to create client stack: %s", err)
   152  	}
   153  	return st, nil
   154  }
   155  
   156  type testContext struct {
   157  	clientStk *stack.Stack
   158  	serverStk *stack.Stack
   159  	peerFDs   [2]int
   160  }
   161  
   162  func newTestContext(t *testing.T) *testContext {
   163  	peerFDs, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET|syscall.SOCK_NONBLOCK, 0)
   164  	if err != nil {
   165  		t.Fatalf("failed to create peerFDs: %s", err)
   166  	}
   167  	q, err := sharedmem.NewQueuePair(sharedmem.QueueOptions{})
   168  	if err != nil {
   169  		t.Fatalf("failed to create sharedmem queue: %s", err)
   170  	}
   171  	clientStack, err := newClientStack(t, q, peerFDs[0])
   172  	if err != nil {
   173  		q.Close()
   174  		unix.Close(peerFDs[0])
   175  		unix.Close(peerFDs[1])
   176  		t.Fatalf("failed to create client stack: %s", err)
   177  	}
   178  	serverStack, err := newServerStack(t, q, peerFDs[1])
   179  	if err != nil {
   180  		q.Close()
   181  		unix.Close(peerFDs[0])
   182  		unix.Close(peerFDs[1])
   183  		clientStack.Close()
   184  		t.Fatalf("failed to create server stack: %s", err)
   185  	}
   186  	return &testContext{
   187  		clientStk: clientStack,
   188  		serverStk: serverStack,
   189  		peerFDs:   peerFDs,
   190  	}
   191  }
   192  
   193  func (ctx *testContext) cleanup() {
   194  	ctx.clientStk.RemoveNIC(tcpip.NICID(1))
   195  	ctx.serverStk.RemoveNIC(tcpip.NICID(1))
   196  	unix.Close(ctx.peerFDs[0])
   197  	unix.Close(ctx.peerFDs[1])
   198  	ctx.clientStk.Close()
   199  	ctx.serverStk.Close()
   200  	ctx.clientStk.Wait()
   201  	ctx.serverStk.Wait()
   202  }
   203  
   204  func makeRequest(serverAddr tcpip.FullAddress, clientStk *stack.Stack) (*http.Response, error) {
   205  	dialFunc := func(address, protocol string) (net.Conn, error) {
   206  		return gonet.DialTCP(clientStk, serverAddr, ipv4.ProtocolNumber)
   207  	}
   208  	httpClient := &http.Client{
   209  		Transport: &http.Transport{
   210  			Dial: dialFunc,
   211  		},
   212  	}
   213  	// Close idle "keep alive" connections. If any connections remain open after
   214  	// a test ends, DoLeakCheck() will erroneously detect leaked packets.
   215  	defer httpClient.CloseIdleConnections()
   216  	serverURL := fmt.Sprintf("http://[%s]:%d/", net.IP(serverAddr.Addr.AsSlice()), serverAddr.Port)
   217  	response, err := httpClient.Get(serverURL)
   218  	return response, err
   219  }
   220  
   221  func TestServerRoundTrip(t *testing.T) {
   222  	ctx := newTestContext(t)
   223  	defer ctx.cleanup()
   224  	listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
   225  	l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
   226  	if err != nil {
   227  		t.Fatalf("failed to start TCP Listener: %s", err)
   228  	}
   229  	defer l.Close()
   230  	var responseString = strings.Repeat("response", 8<<10)
   231  	go func() {
   232  		http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   233  			w.Write([]byte(responseString))
   234  		}))
   235  	}()
   236  
   237  	response, err := makeRequest(listenAddr, ctx.clientStk)
   238  	if err != nil {
   239  		t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
   240  	}
   241  	if got, want := response.StatusCode, http.StatusOK; got != want {
   242  		t.Fatalf("unexpected status code got: %d, want: %d", got, want)
   243  	}
   244  	body, err := io.ReadAll(response.Body)
   245  	if err != nil {
   246  		t.Fatalf("io.ReadAll(response.Body) failed: %s", err)
   247  	}
   248  	response.Body.Close()
   249  	if got, want := string(body), responseString; got != want {
   250  		t.Fatalf("unexpected response got: %s, want: %s", got, want)
   251  	}
   252  }
   253  
   254  func TestServerRoundTripStress(t *testing.T) {
   255  	ctx := newTestContext(t)
   256  	defer ctx.cleanup()
   257  	listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
   258  	l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
   259  	if err != nil {
   260  		t.Fatalf("failed to start TCP Listener: %s", err)
   261  	}
   262  	defer l.Close()
   263  	var responseString = strings.Repeat("response", 8<<10)
   264  	go func() {
   265  		http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   266  			w.Write([]byte(responseString))
   267  		}))
   268  	}()
   269  
   270  	var errs errgroup.Group
   271  	for i := 0; i < 1000; i++ {
   272  		errs.Go(func() error {
   273  			response, err := makeRequest(listenAddr, ctx.clientStk)
   274  			if err != nil {
   275  				return fmt.Errorf("httpClient.Get(\"/\") failed: %s", err)
   276  			}
   277  			if got, want := response.StatusCode, http.StatusOK; got != want {
   278  				return fmt.Errorf("unexpected status code got: %d, want: %d", got, want)
   279  			}
   280  			body, err := io.ReadAll(response.Body)
   281  			if err != nil {
   282  				return fmt.Errorf("io.ReadAll(response.Body) failed: %s", err)
   283  			}
   284  			response.Body.Close()
   285  			if got, want := string(body), responseString; got != want {
   286  				return fmt.Errorf("unexpected response got: %s, want: %s", got, want)
   287  			}
   288  			log.Infof("worker: read %d bytes", len(body))
   289  			return nil
   290  		})
   291  	}
   292  	if err := errs.Wait(); err != nil {
   293  		t.Fatalf("request failed: %s", err)
   294  	}
   295  }
   296  
   297  func TestServerBulkTransfer(t *testing.T) {
   298  	var payloadSizes = []int{
   299  		512 << 20,  // 512 MiB
   300  		1024 << 20, // 1 GiB
   301  		2048 << 20, // 2 GiB
   302  	}
   303  
   304  	for _, payloadSize := range payloadSizes {
   305  		t.Run(fmt.Sprintf("%d bytes", payloadSize), func(t *testing.T) {
   306  			ctx := newTestContext(t)
   307  			defer ctx.cleanup()
   308  			listenAddr := tcpip.FullAddress{Addr: remoteIPv4Address, Port: serverPort}
   309  			l, err := gonet.ListenTCP(ctx.serverStk, listenAddr, ipv4.ProtocolNumber)
   310  			if err != nil {
   311  				t.Fatalf("failed to start TCP Listener: %s", err)
   312  			}
   313  			defer l.Close()
   314  
   315  			const chunkSize = 4 << 20 // 4 MiB
   316  			var responseString = strings.Repeat("r", chunkSize)
   317  			go func() {
   318  				http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   319  					for done := 0; done < payloadSize; {
   320  						n, err := w.Write([]byte(responseString))
   321  						if err != nil {
   322  							log.Infof("failed to write response : %s", err)
   323  							return
   324  						}
   325  						done += n
   326  					}
   327  				}))
   328  			}()
   329  
   330  			response, err := makeRequest(listenAddr, ctx.clientStk)
   331  			if err != nil {
   332  				t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
   333  			}
   334  			if got, want := response.StatusCode, http.StatusOK; got != want {
   335  				t.Fatalf("unexpected status code got: %d, want: %d", got, want)
   336  			}
   337  			n, err := io.Copy(io.Discard, response.Body)
   338  			if err != nil {
   339  				t.Fatalf("io.Copy(io.Discard, response.Body) failed: %s", err)
   340  			}
   341  			response.Body.Close()
   342  			if got, want := int(n), payloadSize; got != want {
   343  				t.Fatalf("unexpected response size got: %d, want: %d", got, want)
   344  			}
   345  			log.Infof("read %d bytes", n)
   346  		})
   347  	}
   348  
   349  }
   350  
   351  func TestClientBulkTransfer(t *testing.T) {
   352  	var payloadSizes = []int{
   353  		512 << 20,  // 512 MiB
   354  		1024 << 20, // 1 GiB
   355  		2048 << 20, // 2 GiB
   356  	}
   357  
   358  	for _, payloadSize := range payloadSizes {
   359  		t.Run(fmt.Sprintf("%d bytes", payloadSize), func(t *testing.T) {
   360  			ctx := newTestContext(t)
   361  			defer ctx.cleanup()
   362  			listenAddr := tcpip.FullAddress{Addr: localIPv4Address, Port: serverPort}
   363  			l, err := gonet.ListenTCP(ctx.clientStk, listenAddr, ipv4.ProtocolNumber)
   364  			if err != nil {
   365  				t.Fatalf("failed to start TCP Listener: %s", err)
   366  			}
   367  			defer l.Close()
   368  			const chunkSize = 4 << 20 // 4 MiB
   369  			var responseString = strings.Repeat("r", chunkSize)
   370  			go func() {
   371  				http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   372  					for done := 0; done < payloadSize; {
   373  						n, err := w.Write([]byte(responseString))
   374  						if err != nil {
   375  							log.Infof("failed to write response : %s", err)
   376  							return
   377  						}
   378  						done += n
   379  					}
   380  				}))
   381  			}()
   382  
   383  			response, err := makeRequest(listenAddr, ctx.serverStk)
   384  			if err != nil {
   385  				t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
   386  			}
   387  			if err != nil {
   388  				t.Fatalf("httpClient.Get(\"/\") failed: %s", err)
   389  			}
   390  			if got, want := response.StatusCode, http.StatusOK; got != want {
   391  				t.Fatalf("unexpected status code got: %d, want: %d", got, want)
   392  			}
   393  			n, err := io.Copy(io.Discard, response.Body)
   394  			if err != nil {
   395  				t.Fatalf("io.Copy(io.Discard, response.Body) failed: %s", err)
   396  			}
   397  			response.Body.Close()
   398  			if got, want := int(n), payloadSize; got != want {
   399  				t.Fatalf("unexpected response size got: %d, want: %d", got, want)
   400  			}
   401  			log.Infof("read %d bytes", n)
   402  		})
   403  	}
   404  }
   405  
   406  func TestMain(m *testing.M) {
   407  	refs.SetLeakMode(refs.LeaksPanic)
   408  	code := m.Run()
   409  	refs.DoLeakCheck()
   410  	os.Exit(code)
   411  }