gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/lisafs/connection_test.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package connection_test
    16  
    17  import (
    18  	"reflect"
    19  	"testing"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"gvisor.dev/gvisor/pkg/abi/linux"
    23  	"gvisor.dev/gvisor/pkg/lisafs"
    24  	"gvisor.dev/gvisor/pkg/sync"
    25  	"gvisor.dev/gvisor/pkg/unet"
    26  )
    27  
    28  const (
    29  	dynamicMsgID = lisafs.Channel + 1
    30  	versionMsgID = dynamicMsgID + 1
    31  )
    32  
    33  var handlers = [...]lisafs.RPCHandler{
    34  	lisafs.Error:   lisafs.ErrorHandler,
    35  	lisafs.Mount:   lisafs.MountHandler,
    36  	lisafs.Channel: lisafs.ChannelHandler,
    37  	dynamicMsgID:   dynamicMsgHandler,
    38  	versionMsgID:   versionHandler,
    39  }
    40  
// testServer is the lisafs.ServerImpl used by these tests. It embeds
// lisafs.Server, whose promoted methods (Init, SetHandlers, CreateConnection,
// StartConnection, Wait, Destroy) drive the server, and defines the
// ServerImpl methods itself.
type testServer struct {
	lisafs.Server
}

// Compile-time assertion that *testServer satisfies lisafs.ServerImpl.
var _ lisafs.ServerImpl = (*testServer)(nil)
    47  
    48  type testControlFD struct {
    49  	lisafs.ControlFD
    50  	lisafs.ControlFDImpl
    51  }
    52  
    53  func (fd *testControlFD) FD() *lisafs.ControlFD {
    54  	return &fd.ControlFD
    55  }
    56  
    57  func (fd *testControlFD) Close() {}
    58  
    59  // Mount implements lisafs.Mount.
    60  func (s *testServer) Mount(c *lisafs.Connection, mountNode *lisafs.Node) (*lisafs.ControlFD, linux.Statx, int, error) {
    61  	dummyRoot := &testControlFD{}
    62  	mountNode.IncRef() // Ref is transferred to ControlFD.
    63  	dummyRoot.Init(c, mountNode, linux.ModeDirectory, dummyRoot)
    64  	return dummyRoot.FD(), linux.Statx{Mode: linux.S_IFDIR}, -1, nil
    65  }
    66  
    67  // MaxMessageSize implements lisafs.MaxMessageSize.
    68  func (s *testServer) MaxMessageSize() uint32 {
    69  	return lisafs.MaxMessageSize()
    70  }
    71  
    72  // SupportedMessages implements lisafs.ServerImpl.SupportedMessages.
    73  func (s *testServer) SupportedMessages() []lisafs.MID {
    74  	return []lisafs.MID{
    75  		lisafs.Mount,
    76  		lisafs.Channel,
    77  		dynamicMsgID,
    78  		versionMsgID,
    79  	}
    80  }
    81  
    82  func runServerClient(t testing.TB, clientFn func(c *lisafs.Client)) {
    83  	serverSocket, clientSocket, err := unet.SocketPair(false)
    84  	if err != nil {
    85  		t.Fatalf("socketpair got err %v expected nil", err)
    86  	}
    87  
    88  	ts := &testServer{}
    89  	ts.Init(ts, lisafs.ServerOpts{})
    90  	ts.SetHandlers(handlers[:])
    91  	conn, err := ts.CreateConnection(serverSocket, "/" /* mountPath */, false /* readonly */)
    92  	if err != nil {
    93  		t.Fatalf("starting connection failed: %v", err)
    94  		return
    95  	}
    96  	ts.StartConnection(conn)
    97  
    98  	c, _, _, err := lisafs.NewClient(clientSocket)
    99  	if err != nil {
   100  		t.Fatalf("client creation failed: %v", err)
   101  	}
   102  	if err := c.StartChannels(); err != nil {
   103  		t.Fatalf("failed to start channels: %v", err)
   104  	}
   105  
   106  	clientFn(c)
   107  
   108  	c.Close() // This should trigger client and server shutdown.
   109  	ts.Wait()
   110  	ts.Server.Destroy()
   111  }
   112  
   113  // TestStartUp tests that the server and client can be started up correctly.
   114  func TestStartUp(t *testing.T) {
   115  	runServerClient(t, func(c *lisafs.Client) {
   116  		if c.IsSupported(lisafs.Error) {
   117  			t.Errorf("sending error messages should not be supported")
   118  		}
   119  	})
   120  }
   121  
   122  func TestUnsupportedMessage(t *testing.T) {
   123  	unsupportedM := lisafs.MID(len(handlers))
   124  	var em lisafs.EmptyMessage
   125  	runServerClient(t, func(c *lisafs.Client) {
   126  		if err := c.SndRcvMessage(unsupportedM, uint32(em.SizeBytes()), em.MarshalBytes, em.CheckedUnmarshal, nil, em.String, em.String); err != unix.EOPNOTSUPP {
   127  			t.Errorf("expected EOPNOTSUPP but got err: %v", err)
   128  		}
   129  	})
   130  }
   131  
   132  func dynamicMsgHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
   133  	var req lisafs.MsgDynamic
   134  	if _, ok := req.CheckedUnmarshal(comm.PayloadBuf(payloadLen)); !ok {
   135  		return 0, unix.EIO
   136  	}
   137  
   138  	// Just echo back the message.
   139  	respPayloadLen := uint32(req.SizeBytes())
   140  	req.MarshalBytes(comm.PayloadBuf(respPayloadLen))
   141  	return respPayloadLen, nil
   142  }
   143  
   144  // TestStress stress tests sending many messages from various goroutines.
   145  func TestStress(t *testing.T) {
   146  	runServerClient(t, func(c *lisafs.Client) {
   147  		concurrency := 8
   148  		numMsgPerGoroutine := 5000
   149  		var clientWg sync.WaitGroup
   150  		for i := 0; i < concurrency; i++ {
   151  			clientWg.Add(1)
   152  			go func() {
   153  				defer clientWg.Done()
   154  
   155  				for j := 0; j < numMsgPerGoroutine; j++ {
   156  					// Create a massive random message.
   157  					var req lisafs.MsgDynamic
   158  					req.Randomize(100)
   159  
   160  					var resp lisafs.MsgDynamic
   161  					if err := c.SndRcvMessage(dynamicMsgID, uint32(req.SizeBytes()), req.MarshalBytes, resp.CheckedUnmarshal, nil, req.String, resp.String); err != nil {
   162  						t.Errorf("SndRcvMessage: received unexpected error %v", err)
   163  						return
   164  					}
   165  					if !reflect.DeepEqual(&req, &resp) {
   166  						t.Errorf("response should be the same as request: request = %+v, response = %+v", req, resp)
   167  					}
   168  				}
   169  			}()
   170  		}
   171  
   172  		clientWg.Wait()
   173  	})
   174  }
   175  
   176  func versionHandler(c *lisafs.Connection, comm lisafs.Communicator, payloadLen uint32) (uint32, error) {
   177  	// To be fair, usually handlers will create their own objects and return a
   178  	// pointer to those. Might be tempting to reuse above variables, but don't.
   179  	var rv lisafs.P9Version
   180  	if _, ok := rv.CheckedUnmarshal(comm.PayloadBuf(payloadLen)); !ok {
   181  		return 0, unix.EIO
   182  	}
   183  
   184  	// Create a new response.
   185  	sv := lisafs.P9Version{
   186  		MSize:   rv.MSize,
   187  		Version: "9P2000.L.Google.11",
   188  	}
   189  	respPayloadLen := uint32(sv.SizeBytes())
   190  	sv.MarshalBytes(comm.PayloadBuf(respPayloadLen))
   191  	return respPayloadLen, nil
   192  }
   193  
   194  // BenchmarkSendRecv exists to compete against p9's BenchmarkSendRecvChannel.
   195  func BenchmarkSendRecv(b *testing.B) {
   196  	b.ReportAllocs()
   197  	sendV := lisafs.P9Version{
   198  		MSize:   1 << 20,
   199  		Version: "9P2000.L.Google.12",
   200  	}
   201  
   202  	var recvV lisafs.P9Version
   203  	runServerClient(b, func(c *lisafs.Client) {
   204  		for i := 0; i < b.N; i++ {
   205  			if err := c.SndRcvMessage(versionMsgID, uint32(sendV.SizeBytes()), sendV.MarshalBytes, recvV.CheckedUnmarshal, nil, sendV.String, recvV.String); err != nil {
   206  				b.Fatalf("unexpected error occurred: %v", err)
   207  			}
   208  		}
   209  	})
   210  }