gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/lisafs/sock_test.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package lisafs
    16  
    17  import (
    18  	"bytes"
    19  	"math/rand"
    20  	"reflect"
    21  	"testing"
    22  
    23  	"golang.org/x/sys/unix"
    24  	"gvisor.dev/gvisor/pkg/marshal"
    25  	"gvisor.dev/gvisor/pkg/sync"
    26  	"gvisor.dev/gvisor/pkg/unet"
    27  )
    28  
    29  func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) {
    30  	sock1, sock2, err := unet.SocketPair(false)
    31  	if err != nil {
    32  		t.Fatalf("socketpair got err %v expected nil", err)
    33  	}
    34  	defer sock1.Close()
    35  	defer sock2.Close()
    36  
    37  	var testWg sync.WaitGroup
    38  	testWg.Add(2)
    39  
    40  	go func() {
    41  		fun1(newSockComm(sock1))
    42  		testWg.Done()
    43  	}()
    44  
    45  	go func() {
    46  		fun2(newSockComm(sock2))
    47  		testWg.Done()
    48  	}()
    49  
    50  	testWg.Wait()
    51  }
    52  
    53  func TestReadWrite(t *testing.T) {
    54  	// Create random data to send.
    55  	n := 10000
    56  	data := make([]byte, n)
    57  	if _, err := rand.Read(data); err != nil {
    58  		t.Fatalf("rand.Read(data) failed: %v", err)
    59  	}
    60  
    61  	runSocketTest(t, func(comm *sockCommunicator) {
    62  		// Scatter that data into two parts using Iovecs while sending.
    63  		mid := n / 2
    64  		if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil {
    65  			t.Errorf("writeTo socket failed: %v", err)
    66  		}
    67  	}, func(comm *sockCommunicator) {
    68  		gotData := make([]byte, n)
    69  		if _, err := readFrom(comm.sock, gotData, 0); err != nil {
    70  			t.Fatalf("reading from socket failed: %v", err)
    71  		}
    72  
    73  		// Make sure we got the right data.
    74  		if res := bytes.Compare(data, gotData); res != 0 {
    75  			t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
    76  		}
    77  	})
    78  }
    79  
    80  func TestFDDonation(t *testing.T) {
    81  	n := 10
    82  	data := make([]byte, n)
    83  	if _, err := rand.Read(data); err != nil {
    84  		t.Fatalf("rand.Read(data) failed: %v", err)
    85  	}
    86  
    87  	// Try donating FDs to these files.
    88  	path1 := "/dev/null"
    89  	path2 := "/dev"
    90  	path3 := "/dev/random"
    91  
    92  	runSocketTest(t, func(comm *sockCommunicator) {
    93  		devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0)
    94  		defer unix.Close(devNullFD)
    95  		if err != nil {
    96  			t.Fatalf("open(%s) failed: %v", path1, err)
    97  		}
    98  		devFD, err := unix.Open(path2, unix.O_RDONLY, 0)
    99  		defer unix.Close(devFD)
   100  		if err != nil {
   101  			t.Fatalf("open(%s) failed: %v", path2, err)
   102  		}
   103  		devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0)
   104  		defer unix.Close(devRandomFD)
   105  		if err != nil {
   106  			t.Fatalf("open(%s) failed: %v", path2, err)
   107  		}
   108  		if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil {
   109  			t.Errorf("writeTo socket failed: %v", err)
   110  		}
   111  	}, func(comm *sockCommunicator) {
   112  		gotData := make([]byte, n)
   113  		fds, err := readFrom(comm.sock, gotData, 3)
   114  		if err != nil {
   115  			t.Fatalf("reading from socket failed: %v", err)
   116  		}
   117  		defer closeFDs(fds[:])
   118  
   119  		if res := bytes.Compare(data, gotData); res != 0 {
   120  			t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData)
   121  		}
   122  
   123  		if len(fds) != 3 {
   124  			t.Fatalf("wanted 3 FD, got %d", len(fds))
   125  		}
   126  
   127  		// Check that the FDs actually point to the correct file.
   128  		compareFDWithFile(t, fds[0], path1)
   129  		compareFDWithFile(t, fds[1], path2)
   130  		compareFDWithFile(t, fds[2], path3)
   131  	})
   132  }
   133  
   134  func compareFDWithFile(t *testing.T, fd int, path string) {
   135  	var want unix.Stat_t
   136  	if err := unix.Stat(path, &want); err != nil {
   137  		t.Fatalf("stat(%s) failed: %v", path, err)
   138  	}
   139  
   140  	var got unix.Stat_t
   141  	if err := unix.Fstat(fd, &got); err != nil {
   142  		t.Fatalf("fstat on donated FD failed: %v", err)
   143  	}
   144  
   145  	if got.Ino != want.Ino || got.Dev != want.Dev {
   146  		t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got)
   147  	}
   148  }
   149  
   150  func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error {
   151  	var payloadLen uint32
   152  	if msg != nil {
   153  		payloadLen = uint32(msg.SizeBytes())
   154  		msg.MarshalUnsafe(comm.PayloadBuf(payloadLen))
   155  	}
   156  	return comm.sndPrepopulatedMsg(m, payloadLen, nil)
   157  }
   158  
   159  func TestSndRcvMessage(t *testing.T) {
   160  	req := &MsgSimple{}
   161  	req.Randomize()
   162  	reqM := MID(1)
   163  
   164  	// Create a massive random response.
   165  	var resp MsgDynamic
   166  	resp.Randomize(100)
   167  	respM := MID(2)
   168  
   169  	runSocketTest(t, func(comm *sockCommunicator) {
   170  		if err := testSndMsg(comm, reqM, req); err != nil {
   171  			t.Errorf("writeMessageTo failed: %v", err)
   172  		}
   173  		checkMessageReceive(t, comm, respM, &resp)
   174  	}, func(comm *sockCommunicator) {
   175  		checkMessageReceive(t, comm, reqM, req)
   176  		if err := testSndMsg(comm, respM, &resp); err != nil {
   177  			t.Errorf("writeMessageTo failed: %v", err)
   178  		}
   179  	})
   180  }
   181  
   182  func TestSndRcvMessageNoPayload(t *testing.T) {
   183  	reqM := MID(1)
   184  	respM := MID(2)
   185  	runSocketTest(t, func(comm *sockCommunicator) {
   186  		if err := testSndMsg(comm, reqM, nil); err != nil {
   187  			t.Errorf("writeMessageTo failed: %v", err)
   188  		}
   189  		checkMessageReceive(t, comm, respM, nil)
   190  	}, func(comm *sockCommunicator) {
   191  		checkMessageReceive(t, comm, reqM, nil)
   192  		if err := testSndMsg(comm, respM, nil); err != nil {
   193  			t.Errorf("writeMessageTo failed: %v", err)
   194  		}
   195  	})
   196  }
   197  
   198  func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) {
   199  	gotM, payloadLen, err := comm.rcvMsg(0)
   200  	if err != nil {
   201  		t.Fatalf("readMessageFrom failed: %v", err)
   202  	}
   203  	if gotM != wantM {
   204  		t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM)
   205  	}
   206  	if wantMsg == nil {
   207  		if payloadLen != 0 {
   208  			t.Errorf("no payload expect but got %d bytes", payloadLen)
   209  		}
   210  	} else {
   211  		gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable)
   212  		gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen))
   213  		if !reflect.DeepEqual(wantMsg, gotMsg) {
   214  			t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg)
   215  		}
   216  	}
   217  }