gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/lisafs/sock_test.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package lisafs 16 17 import ( 18 "bytes" 19 "math/rand" 20 "reflect" 21 "testing" 22 23 "golang.org/x/sys/unix" 24 "gvisor.dev/gvisor/pkg/marshal" 25 "gvisor.dev/gvisor/pkg/sync" 26 "gvisor.dev/gvisor/pkg/unet" 27 ) 28 29 func runSocketTest(t *testing.T, fun1 func(*sockCommunicator), fun2 func(*sockCommunicator)) { 30 sock1, sock2, err := unet.SocketPair(false) 31 if err != nil { 32 t.Fatalf("socketpair got err %v expected nil", err) 33 } 34 defer sock1.Close() 35 defer sock2.Close() 36 37 var testWg sync.WaitGroup 38 testWg.Add(2) 39 40 go func() { 41 fun1(newSockComm(sock1)) 42 testWg.Done() 43 }() 44 45 go func() { 46 fun2(newSockComm(sock2)) 47 testWg.Done() 48 }() 49 50 testWg.Wait() 51 } 52 53 func TestReadWrite(t *testing.T) { 54 // Create random data to send. 55 n := 10000 56 data := make([]byte, n) 57 if _, err := rand.Read(data); err != nil { 58 t.Fatalf("rand.Read(data) failed: %v", err) 59 } 60 61 runSocketTest(t, func(comm *sockCommunicator) { 62 // Scatter that data into two parts using Iovecs while sending. 63 mid := n / 2 64 if err := writeTo(comm.sock, [][]byte{data[:mid], data[mid:]}, n, nil); err != nil { 65 t.Errorf("writeTo socket failed: %v", err) 66 } 67 }, func(comm *sockCommunicator) { 68 gotData := make([]byte, n) 69 if _, err := readFrom(comm.sock, gotData, 0); err != nil { 70 t.Fatalf("reading from socket failed: %v", err) 71 } 72 73 // Make sure we got the right data. 74 if res := bytes.Compare(data, gotData); res != 0 { 75 t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) 76 } 77 }) 78 } 79 80 func TestFDDonation(t *testing.T) { 81 n := 10 82 data := make([]byte, n) 83 if _, err := rand.Read(data); err != nil { 84 t.Fatalf("rand.Read(data) failed: %v", err) 85 } 86 87 // Try donating FDs to these files. 88 path1 := "/dev/null" 89 path2 := "/dev" 90 path3 := "/dev/random" 91 92 runSocketTest(t, func(comm *sockCommunicator) { 93 devNullFD, err := unix.Open(path1, unix.O_RDONLY, 0) 94 defer unix.Close(devNullFD) 95 if err != nil { 96 t.Fatalf("open(%s) failed: %v", path1, err) 97 } 98 devFD, err := unix.Open(path2, unix.O_RDONLY, 0) 99 defer unix.Close(devFD) 100 if err != nil { 101 t.Fatalf("open(%s) failed: %v", path2, err) 102 } 103 devRandomFD, err := unix.Open(path3, unix.O_RDONLY, 0) 104 defer unix.Close(devRandomFD) 105 if err != nil { 106 t.Fatalf("open(%s) failed: %v", path2, err) 107 } 108 if err := writeTo(comm.sock, [][]byte{data}, n, []int{devNullFD, devFD, devRandomFD}); err != nil { 109 t.Errorf("writeTo socket failed: %v", err) 110 } 111 }, func(comm *sockCommunicator) { 112 gotData := make([]byte, n) 113 fds, err := readFrom(comm.sock, gotData, 3) 114 if err != nil { 115 t.Fatalf("reading from socket failed: %v", err) 116 } 117 defer closeFDs(fds[:]) 118 119 if res := bytes.Compare(data, gotData); res != 0 { 120 t.Errorf("data received differs from data sent, want = %v, got = %v", data, gotData) 121 } 122 123 if len(fds) != 3 { 124 t.Fatalf("wanted 3 FD, got %d", len(fds)) 125 } 126 127 // Check that the FDs actually point to the correct file. 128 compareFDWithFile(t, fds[0], path1) 129 compareFDWithFile(t, fds[1], path2) 130 compareFDWithFile(t, fds[2], path3) 131 }) 132 } 133 134 func compareFDWithFile(t *testing.T, fd int, path string) { 135 var want unix.Stat_t 136 if err := unix.Stat(path, &want); err != nil { 137 t.Fatalf("stat(%s) failed: %v", path, err) 138 } 139 140 var got unix.Stat_t 141 if err := unix.Fstat(fd, &got); err != nil { 142 t.Fatalf("fstat on donated FD failed: %v", err) 143 } 144 145 if got.Ino != want.Ino || got.Dev != want.Dev { 146 t.Errorf("FD does not point to %s, want = %+v, got = %+v", path, want, got) 147 } 148 } 149 150 func testSndMsg(comm *sockCommunicator, m MID, msg marshal.Marshallable) error { 151 var payloadLen uint32 152 if msg != nil { 153 payloadLen = uint32(msg.SizeBytes()) 154 msg.MarshalUnsafe(comm.PayloadBuf(payloadLen)) 155 } 156 return comm.sndPrepopulatedMsg(m, payloadLen, nil) 157 } 158 159 func TestSndRcvMessage(t *testing.T) { 160 req := &MsgSimple{} 161 req.Randomize() 162 reqM := MID(1) 163 164 // Create a massive random response. 165 var resp MsgDynamic 166 resp.Randomize(100) 167 respM := MID(2) 168 169 runSocketTest(t, func(comm *sockCommunicator) { 170 if err := testSndMsg(comm, reqM, req); err != nil { 171 t.Errorf("writeMessageTo failed: %v", err) 172 } 173 checkMessageReceive(t, comm, respM, &resp) 174 }, func(comm *sockCommunicator) { 175 checkMessageReceive(t, comm, reqM, req) 176 if err := testSndMsg(comm, respM, &resp); err != nil { 177 t.Errorf("writeMessageTo failed: %v", err) 178 } 179 }) 180 } 181 182 func TestSndRcvMessageNoPayload(t *testing.T) { 183 reqM := MID(1) 184 respM := MID(2) 185 runSocketTest(t, func(comm *sockCommunicator) { 186 if err := testSndMsg(comm, reqM, nil); err != nil { 187 t.Errorf("writeMessageTo failed: %v", err) 188 } 189 checkMessageReceive(t, comm, respM, nil) 190 }, func(comm *sockCommunicator) { 191 checkMessageReceive(t, comm, reqM, nil) 192 if err := testSndMsg(comm, respM, nil); err != nil { 193 t.Errorf("writeMessageTo failed: %v", err) 194 } 195 }) 196 } 197 198 func checkMessageReceive(t *testing.T, comm *sockCommunicator, wantM MID, wantMsg marshal.Marshallable) { 199 gotM, payloadLen, err := comm.rcvMsg(0) 200 if err != nil { 201 t.Fatalf("readMessageFrom failed: %v", err) 202 } 203 if gotM != wantM { 204 t.Errorf("got incorrect message ID: got = %d, want = %d", gotM, wantM) 205 } 206 if wantMsg == nil { 207 if payloadLen != 0 { 208 t.Errorf("no payload expect but got %d bytes", payloadLen) 209 } 210 } else { 211 gotMsg := reflect.New(reflect.ValueOf(wantMsg).Elem().Type()).Interface().(marshal.Marshallable) 212 gotMsg.UnmarshalUnsafe(comm.PayloadBuf(payloadLen)) 213 if !reflect.DeepEqual(wantMsg, gotMsg) { 214 t.Errorf("msg differs: want = %+v, got = %+v", wantMsg, gotMsg) 215 } 216 } 217 }