github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/fsimpl/fuse/dev_test.go

// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package fuse

import (
	"fmt"
	"math/rand"
	"testing"

	"github.com/SagerNet/gvisor/pkg/abi/linux"
	"github.com/SagerNet/gvisor/pkg/sentry/fsimpl/testutil"
	"github.com/SagerNet/gvisor/pkg/sentry/kernel"
	"github.com/SagerNet/gvisor/pkg/sentry/kernel/auth"
	"github.com/SagerNet/gvisor/pkg/sentry/vfs"
	"github.com/SagerNet/gvisor/pkg/syserror"
	"github.com/SagerNet/gvisor/pkg/usermem"
	"github.com/SagerNet/gvisor/pkg/waiter"
)

// echoTestOpcode is the Opcode used during testing. The server used in tests
// will simply echo the payload back with the appropriate headers.
const echoTestOpcode linux.FUSEOpcode = 1000

// TestFUSECommunication tests that the communication layer between the Sentry and the
// FUSE server daemon works as expected.
func TestFUSECommunication(t *testing.T) {
	s := setup(t)
	defer s.Destroy()

	k := kernel.KernelFromContext(s.Ctx)
	creds := auth.CredentialsFromContext(s.Ctx)

	// Create test cases with different numbers of concurrent clients and servers.
	testCases := []struct {
		Name              string
		NumClients        int
		NumServers        int
		MaxActiveRequests uint64
	}{
		{
			Name:              "SingleClientSingleServer",
			NumClients:        1,
			NumServers:        1,
			MaxActiveRequests: maxActiveRequestsDefault,
		},
		{
			Name:              "SingleClientMultipleServers",
			NumClients:        1,
			NumServers:        10,
			MaxActiveRequests: maxActiveRequestsDefault,
		},
		{
			Name:              "MultipleClientsSingleServer",
			NumClients:        10,
			NumServers:        1,
			MaxActiveRequests: maxActiveRequestsDefault,
		},
		{
			Name:              "MultipleClientsMultipleServers",
			NumClients:        10,
			NumServers:        10,
			MaxActiveRequests: maxActiveRequestsDefault,
		},
		{
			Name:              "RequestCapacityFull",
			NumClients:        10,
			NumServers:        1,
			MaxActiveRequests: 1,
		},
		{
			Name:              "RequestCapacityContinuouslyFull",
			NumClients:        100,
			NumServers:        2,
			MaxActiveRequests: 2,
		},
	}

	for _, testCase := range testCases {
		t.Run(testCase.Name, func(t *testing.T) {
			conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests)
			if err != nil {
				t.Fatalf("newTestConnection: %v", err)
			}

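			// Channels used to wait for every client and server goroutine to
			// finish, and to tell servers that may be blocked in Read to shut down.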
			clientsDone := make([]chan struct{}, testCase.NumClients)
			serversDone := make([]chan struct{}, testCase.NumServers)
			serversKill := make([]chan struct{}, testCase.NumServers)

			// FUSE clients.
			for i := 0; i < testCase.NumClients; i++ {
				clientsDone[i] = make(chan struct{})
				go func(i int) {
					fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i])
				}(i)
			}

			// FUSE servers.
			for j := 0; j < testCase.NumServers; j++ {
				serversDone[j] = make(chan struct{})
				serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block.
				go func(j int) {
					fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j])
				}(j)
			}

			// Tear down.
			//
			// Make sure all the clients are done.
			for i := 0; i < testCase.NumClients; i++ {
				<-clientsDone[i]
			}

			// Kill any server that is potentially waiting.
			for j := 0; j < testCase.NumServers; j++ {
				serversKill[j] <- struct{}{}
			}

			// Make sure all the servers are done.
			for j := 0; j < testCase.NumServers; j++ {
				<-serversDone[j]
			}
		})
	}
}

// CallTest makes a request to the server and blocks the invoking
// goroutine until the server responds. It does not block a
// kernel.Task. Analogous to Connection.Call but used for testing.
func CallTest(conn *connection, t *kernel.Task, r *Request, i uint32) (*Response, error) {
	conn.fd.mu.Lock()

	// Wait until we're certain that a new request can be processed.
	for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
		conn.fd.mu.Unlock()
		<-conn.fd.fullQueueCh
		conn.fd.mu.Lock()
	}

	fut, err := conn.callFutureLocked(t, r) // Queue the request without blocking the task.
	conn.fd.mu.Unlock()

	if err != nil {
		return nil, err
	}

	// Resolve the response.
	//
	// Block without a task.
	<-fut.ch

	// A response is ready. Resolve and return it.
	return fut.getResponse(), nil
}

// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE
// device. However, instead of blocking the calling task it just waits on a
// channel. The behaviour is essentially the same as DeviceFD.Read except
// that the task is guaranteed not to block.
func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) {
	var err error
	var n, total int64

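	// The FUSE device implementation backing the given file description.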
	dev := fd.Impl().(*DeviceFD)

	// Register for notifications.
	w, ch := waiter.NewChannelEntry(nil)
	dev.EventRegister(&w, waiter.ReadableEvents)
	for {
		// Issue the request and break out if it completes with anything other than
		// "would block".
		n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{})
		total += n
		if err != syserror.ErrWouldBlock {
			break
		}

		// Wait for a notification that we should retry.
		// Emulate blocking when no requests are available.
		select {
		case <-ch:
		case <-killServer:
			// Server killed by the main test goroutine.
			return 0, true, nil
		}
	}

	dev.EventUnregister(&w)
	return total, false, err
}

// fuseClientRun emulates all the actions of a normal FUSE request. It creates
// a header, a payload, calls the server, waits for the response, and processes
// the response.
func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) {
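	// Notify the test that this client goroutine has finished.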
	defer func() { clientDone <- struct{}{} }()

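	// Create a thread group and task to stand in for the FUSE client.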
	tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
	clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root)
	if err != nil {
		t.Fatal(err)
	}
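	// Use a random payload so the client can verify that the server echoed
	// back exactly the data it sent.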
	testObj := &testPayload{
		data: rand.Uint32(),
	}

	req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)

	// Queue up a request.
	// Analogous to Call except it doesn't block on the task.
	resp, err := CallTest(conn, clientTask, req, pid)
	if err != nil {
		t.Fatalf("CallTest failed: %v", err)
	}

	if err = resp.Error(); err != nil {
		t.Fatalf("Server responded with an error: %v", err)
	}

	var respTestPayload testPayload
	if err := resp.UnmarshalPayload(&respTestPayload); err != nil {
		t.Fatalf("Unmarshalling payload error: %v", err)
	}

	if resp.hdr.Unique != req.hdr.Unique {
		t.Fatalf("got response for another request. Expected response for req %v but got response for req %v",
			req.hdr.Unique, resp.hdr.Unique)
	}

	if respTestPayload.data != testObj.data {
		t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data)
	}
}

// fuseServerRun creates a task and emulates a simple FUSE server that reads
// each request and echoes the same struct back as a response with the
// appropriate headers.
func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) {
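	// Notify the test that this server goroutine has finished.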
	defer func() { serverDone <- struct{}{} }()

	// Create the task that the server will be using.
	tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
	var readPayload testPayload

	serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root)
	if err != nil {
		t.Fatal(err)
	}

	// Serve requests: read each one and echo its payload back until the
	// server is killed.
	for {
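		// Compute the wire sizes of the request header and the test payload.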
		inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
		payloadLen := uint32(readPayload.SizeBytes())

		// The read buffer must be at least FUSE_MIN_READ_BUFFER bytes.
		buffSize := inHdrLen + payloadLen
		if buffSize < linux.FUSE_MIN_READ_BUFFER {
			buffSize = linux.FUSE_MIN_READ_BUFFER
		}
		inBuf := make([]byte, buffSize)
		inIOseq := usermem.BytesIOSequence(inBuf)

		n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer)
		if err != nil {
			t.Fatalf("Read failed: %v", err)
		}

		// Server should shut down. No new requests are going to be made.
		if serverKilled {
			break
		}

		if n <= 0 {
			t.Fatalf("Read returned no bytes")
		}

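		// Deserialize the request header and the echoed payload from the read
		// buffer.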
		var readFUSEHeaderIn linux.FUSEHeaderIn
		readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen])
		readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen])

		if readFUSEHeaderIn.Opcode != echoTestOpcode {
			t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload)
		}

		// Write the response.
		outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
		outBuf := make([]byte, outHdrLen+payloadLen)
		outHeader := linux.FUSEHeaderOut{
			Len:    outHdrLen + payloadLen,
			Error:  0,
			Unique: readFUSEHeaderIn.Unique,
		}

		// Echo the payload back.
		outHeader.MarshalUnsafe(outBuf[:outHdrLen])
		readPayload.MarshalUnsafe(outBuf[outHdrLen:])
		outIOseq := usermem.BytesIOSequence(outBuf)

		_, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
		if err != nil {
			t.Fatalf("Write failed: %v", err)
		}
	}
}