gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/usermem/usermem_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package usermem 16 17 import ( 18 "bytes" 19 "fmt" 20 "slices" 21 "strings" 22 "testing" 23 24 "gvisor.dev/gvisor/pkg/context" 25 "gvisor.dev/gvisor/pkg/errors/linuxerr" 26 "gvisor.dev/gvisor/pkg/hostarch" 27 "gvisor.dev/gvisor/pkg/safemem" 28 ) 29 30 // newContext returns a context.Context that we can use in these tests (we 31 // can't use contexttest because it depends on usermem). 32 func newContext() context.Context { 33 return context.Background() 34 } 35 36 func newBytesIOString(s string) *BytesIO { 37 return &BytesIO{[]byte(s)} 38 } 39 40 func TestBytesIOCopyOutSuccess(t *testing.T) { 41 b := newBytesIOString("ABCDE") 42 n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) 43 if wantN := 3; n != wantN || err != nil { 44 t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 45 } 46 if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) { 47 t.Errorf("Bytes: got %q, wanted %q", got, want) 48 } 49 } 50 51 func TestBytesIOCopyOutFailure(t *testing.T) { 52 b := newBytesIOString("ABC") 53 n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) 54 if wantN, wantErr := 2, linuxerr.EFAULT; n != wantN || err != wantErr { 55 t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 56 } 57 if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) { 58 t.Errorf("Bytes: got %q, wanted %q", got, want) 59 } 60 } 61 62 func TestBytesIOCopyInSuccess(t *testing.T) { 63 b := newBytesIOString("AfooE") 64 var dst [3]byte 65 n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) 66 if wantN := 3; n != wantN || err != nil { 67 t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) 68 } 69 if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { 70 t.Errorf("dst: got %q, wanted %q", got, want) 71 } 72 } 73 74 func TestBytesIOCopyInFailure(t *testing.T) { 75 b := newBytesIOString("Afo") 76 var dst [3]byte 77 n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) 78 if wantN, wantErr := 2, linuxerr.EFAULT; n != wantN || err != wantErr { 79 t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 80 } 81 if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) { 82 t.Errorf("dst: got %q, wanted %q", got, want) 83 } 84 } 85 86 func TestBytesIOZeroOutSuccess(t *testing.T) { 87 b := newBytesIOString("ABCD") 88 n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{}) 89 if wantN := int64(2); n != wantN || err != nil { 90 t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 91 } 92 if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) { 93 t.Errorf("Bytes: got %q, wanted %q", got, want) 94 } 95 } 96 97 func TestBytesIOZeroOutFailure(t *testing.T) { 98 b := newBytesIOString("ABC") 99 n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{}) 100 if wantN, wantErr := int64(2), linuxerr.EFAULT; n != wantN || err != wantErr { 101 t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 102 } 103 if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) { 104 t.Errorf("Bytes: got %q, wanted %q", got, want) 105 } 106 } 107 108 func TestBytesIOCopyOutFromSuccess(t *testing.T) { 109 b := newBytesIOString("ABCDEFGH") 110 n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ 111 {Start: 4, End: 7}, 112 {Start: 1, End: 4}, 113 }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{}) 114 if wantN := int64(6); n != wantN || err != nil { 115 t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN) 116 } 117 if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) { 118 t.Errorf("Bytes: got %q, wanted %q", got, want) 119 } 120 } 121 122 func TestBytesIOCopyOutFromFailure(t *testing.T) { 123 b := newBytesIOString("ABCDE") 124 n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ 125 {Start: 1, End: 4}, 126 {Start: 4, End: 7}, 127 }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{}) 128 if wantN, wantErr := int64(4), linuxerr.EFAULT; n != wantN || err != wantErr { 129 t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 130 } 131 if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) { 132 t.Errorf("Bytes: got %q, wanted %q", got, want) 133 } 134 } 135 136 func TestBytesIOCopyInToSuccess(t *testing.T) { 137 b := newBytesIOString("AfoobarH") 138 var dst bytes.Buffer 139 n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ 140 {Start: 4, End: 7}, 141 {Start: 1, End: 4}, 142 }), safemem.FromIOWriter{&dst}, IOOpts{}) 143 if wantN := int64(6); n != wantN || err != nil { 144 t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN) 145 } 146 if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) { 147 t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) 148 } 149 } 150 151 func TestBytesIOCopyInToFailure(t *testing.T) { 152 b := newBytesIOString("Afoob") 153 var dst bytes.Buffer 154 n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ 155 {Start: 1, End: 4}, 156 {Start: 4, End: 7}, 157 }), safemem.FromIOWriter{&dst}, IOOpts{}) 158 if wantN, wantErr := int64(4), linuxerr.EFAULT; n != wantN || err != wantErr { 159 t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 160 } 161 if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { 162 t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) 163 } 164 } 165 166 type testStruct struct { 167 Int8 int8 168 Uint8 uint8 169 Int16 int16 170 Uint16 uint16 171 Int32 int32 172 Uint32 uint32 173 Int64 int64 174 Uint64 uint64 175 } 176 177 func TestCopyStringInShort(t *testing.T) { 178 // Tests for string length <= copyStringIncrement. 179 want := strings.Repeat("A", copyStringIncrement-2) 180 mem := want + "\x00" 181 if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { 182 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) 183 } 184 } 185 186 func TestCopyStringInLong(t *testing.T) { 187 // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen 188 // (requiring multiple calls to IO.CopyIn()). 189 want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4) 190 mem := want + "\x00" 191 if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { 192 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) 193 } 194 } 195 196 func TestCopyStringInVeryLong(t *testing.T) { 197 // Tests for string length > copyStringMaxInitBufLen (requiring buffer 198 // reallocation). 199 want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4) 200 mem := want + "\x00" 201 if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil { 202 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) 203 } 204 } 205 206 func TestCopyStringInNoTerminatingZeroByte(t *testing.T) { 207 want := strings.Repeat("A", copyStringIncrement-1) 208 got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{}) 209 if wantErr := linuxerr.EFAULT; got != want || err != wantErr { 210 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) 211 } 212 } 213 214 func TestCopyStringInTruncatedByMaxlen(t *testing.T) { 215 got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{}) 216 if want, wantErr := strings.Repeat("A", 5), linuxerr.ENAMETOOLONG; got != want || err != wantErr { 217 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) 218 } 219 } 220 221 func TestCopyInt32StringsInVec(t *testing.T) { 222 for _, test := range []struct { 223 str string 224 n int 225 initial []int32 226 final []int32 227 }{ 228 { 229 str: "100 200", 230 n: len("100 200"), 231 initial: []int32{1, 2}, 232 final: []int32{100, 200}, 233 }, 234 { 235 // Fewer values ok 236 str: "100", 237 n: len("100"), 238 initial: []int32{1, 2}, 239 final: []int32{100, 2}, 240 }, 241 { 242 // Extra values ok 243 str: "100 200 300", 244 n: len("100 200 "), 245 initial: []int32{1, 2}, 246 final: []int32{100, 200}, 247 }, 248 { 249 // Leading and trailing whitespace ok 250 str: " 100\t200\n", 251 n: len(" 100\t200\n"), 252 initial: []int32{1, 2}, 253 final: []int32{100, 200}, 254 }, 255 } { 256 t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { 257 src := BytesIOSequence([]byte(test.str)) 258 dsts := append([]int32(nil), test.initial...) 259 if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil { 260 t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n) 261 } 262 if !slices.Equal(dsts, test.final) { 263 t.Errorf("dsts: got %v, wanted %v", dsts, test.final) 264 } 265 }) 266 } 267 } 268 269 func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { 270 for _, s := range []string{"", "\n", "a123"} { 271 t.Run(fmt.Sprintf("%q", s), func(t *testing.T) { 272 src := BytesIOSequence([]byte(s)) 273 initial := []int32{1, 2} 274 dsts := append([]int32(nil), initial...) 275 if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); !linuxerr.Equals(linuxerr.EINVAL, err) { 276 t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, linuxerr.EINVAL) 277 } 278 if !slices.Equal(dsts, initial) { 279 t.Errorf("dsts: got %v, wanted %v", dsts, initial) 280 } 281 }) 282 } 283 } 284 285 func TestIOSequenceCopyOut(t *testing.T) { 286 buf := []byte("ABCD") 287 s := BytesIOSequence(buf) 288 289 // CopyOut limited by len(src). 290 n, err := s.CopyOut(newContext(), []byte("fo")) 291 if wantN := 2; n != wantN || err != nil { 292 t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 293 } 294 if want := []byte("foCD"); !bytes.Equal(buf, want) { 295 t.Errorf("buf: got %q, wanted %q", buf, want) 296 } 297 s = s.DropFirst(2) 298 if got, want := s.NumBytes(), int64(2); got != want { 299 t.Errorf("NumBytes: got %v, wanted %v", got, want) 300 } 301 302 // CopyOut limited by s.NumBytes(). 303 n, err = s.CopyOut(newContext(), []byte("obar")) 304 if wantN := 2; n != wantN || err != nil { 305 t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 306 } 307 if want := []byte("foob"); !bytes.Equal(buf, want) { 308 t.Errorf("buf: got %q, wanted %q", buf, want) 309 } 310 s = s.DropFirst(2) 311 if got, want := s.NumBytes(), int64(0); got != want { 312 t.Errorf("NumBytes: got %v, wanted %v", got, want) 313 } 314 } 315 316 func TestIOSequenceCopyIn(t *testing.T) { 317 s := BytesIOSequence([]byte("foob")) 318 dst := []byte("ABCDEF") 319 320 // CopyIn limited by len(dst). 321 n, err := s.CopyIn(newContext(), dst[:2]) 322 if wantN := 2; n != wantN || err != nil { 323 t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) 324 } 325 if want := []byte("foCDEF"); !bytes.Equal(dst, want) { 326 t.Errorf("dst: got %q, wanted %q", dst, want) 327 } 328 s = s.DropFirst(2) 329 if got, want := s.NumBytes(), int64(2); got != want { 330 t.Errorf("NumBytes: got %v, wanted %v", got, want) 331 } 332 333 // CopyIn limited by s.Remaining(). 334 n, err = s.CopyIn(newContext(), dst[2:]) 335 if wantN := 2; n != wantN || err != nil { 336 t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) 337 } 338 if want := []byte("foobEF"); !bytes.Equal(dst, want) { 339 t.Errorf("dst: got %q, wanted %q", dst, want) 340 } 341 s = s.DropFirst(2) 342 if got, want := s.NumBytes(), int64(0); got != want { 343 t.Errorf("NumBytes: got %v, wanted %v", got, want) 344 } 345 } 346 347 func TestIOSequenceZeroOut(t *testing.T) { 348 buf := []byte("ABCD") 349 s := BytesIOSequence(buf) 350 351 // ZeroOut limited by toZero. 352 n, err := s.ZeroOut(newContext(), 2) 353 if wantN := int64(2); n != wantN || err != nil { 354 t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 355 } 356 if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) { 357 t.Errorf("buf: got %q, wanted %q", buf, want) 358 } 359 s = s.DropFirst(2) 360 if got, want := s.NumBytes(), int64(2); got != want { 361 t.Errorf("NumBytes: got %v, wanted %v", got, want) 362 } 363 364 // ZeroOut limited by s.NumBytes(). 365 n, err = s.ZeroOut(newContext(), 4) 366 if wantN := int64(2); n != wantN || err != nil { 367 t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 368 } 369 if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) { 370 t.Errorf("buf: got %q, wanted %q", buf, want) 371 } 372 s = s.DropFirst(2) 373 if got, want := s.NumBytes(), int64(0); got != want { 374 t.Errorf("NumBytes: got %v, wanted %v", got, want) 375 } 376 } 377 378 func TestIOSequenceTakeFirst(t *testing.T) { 379 s := BytesIOSequence([]byte("foobar")) 380 if got, want := s.NumBytes(), int64(6); got != want { 381 t.Errorf("NumBytes: got %v, wanted %v", got, want) 382 } 383 384 s = s.TakeFirst(3) 385 if got, want := s.NumBytes(), int64(3); got != want { 386 t.Errorf("NumBytes: got %v, wanted %v", got, want) 387 } 388 389 // TakeFirst(n) where n > s.NumBytes() is a no-op. 390 s = s.TakeFirst(9) 391 if got, want := s.NumBytes(), int64(3); got != want { 392 t.Errorf("NumBytes: got %v, wanted %v", got, want) 393 } 394 395 var dst [3]byte 396 n, err := s.CopyIn(newContext(), dst[:]) 397 if wantN := 3; n != wantN || err != nil { 398 t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) 399 } 400 if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { 401 t.Errorf("dst: got %q, wanted %q", got, want) 402 } 403 s = s.DropFirst(3) 404 if got, want := s.NumBytes(), int64(0); got != want { 405 t.Errorf("NumBytes: got %v, wanted %v", got, want) 406 } 407 }