// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package usermem

import (
	"bytes"
	"fmt"
	"reflect"
	"strings"
	"testing"

	"github.com/SagerNet/gvisor/pkg/context"
	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
	"github.com/SagerNet/gvisor/pkg/hostarch"
	"github.com/SagerNet/gvisor/pkg/safemem"
	"github.com/SagerNet/gvisor/pkg/syserror"
)
33 func newContext() context.Context { 34 return context.Background() 35 } 36 37 func newBytesIOString(s string) *BytesIO { 38 return &BytesIO{[]byte(s)} 39 } 40 41 func TestBytesIOCopyOutSuccess(t *testing.T) { 42 b := newBytesIOString("ABCDE") 43 n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) 44 if wantN := 3; n != wantN || err != nil { 45 t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 46 } 47 if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) { 48 t.Errorf("Bytes: got %q, wanted %q", got, want) 49 } 50 } 51 52 func TestBytesIOCopyOutFailure(t *testing.T) { 53 b := newBytesIOString("ABC") 54 n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{}) 55 if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { 56 t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 57 } 58 if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) { 59 t.Errorf("Bytes: got %q, wanted %q", got, want) 60 } 61 } 62 63 func TestBytesIOCopyInSuccess(t *testing.T) { 64 b := newBytesIOString("AfooE") 65 var dst [3]byte 66 n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) 67 if wantN := 3; n != wantN || err != nil { 68 t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) 69 } 70 if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { 71 t.Errorf("dst: got %q, wanted %q", got, want) 72 } 73 } 74 75 func TestBytesIOCopyInFailure(t *testing.T) { 76 b := newBytesIOString("Afo") 77 var dst [3]byte 78 n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{}) 79 if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr { 80 t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 81 } 82 if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) { 83 t.Errorf("dst: got %q, wanted %q", got, want) 84 } 85 } 86 87 func TestBytesIOZeroOutSuccess(t *testing.T) { 88 b := newBytesIOString("ABCD") 89 n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{}) 90 
if wantN := int64(2); n != wantN || err != nil { 91 t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 92 } 93 if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) { 94 t.Errorf("Bytes: got %q, wanted %q", got, want) 95 } 96 } 97 98 func TestBytesIOZeroOutFailure(t *testing.T) { 99 b := newBytesIOString("ABC") 100 n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{}) 101 if wantN, wantErr := int64(2), syserror.EFAULT; n != wantN || err != wantErr { 102 t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 103 } 104 if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) { 105 t.Errorf("Bytes: got %q, wanted %q", got, want) 106 } 107 } 108 109 func TestBytesIOCopyOutFromSuccess(t *testing.T) { 110 b := newBytesIOString("ABCDEFGH") 111 n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ 112 {Start: 4, End: 7}, 113 {Start: 1, End: 4}, 114 }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{}) 115 if wantN := int64(6); n != wantN || err != nil { 116 t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN) 117 } 118 if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) { 119 t.Errorf("Bytes: got %q, wanted %q", got, want) 120 } 121 } 122 123 func TestBytesIOCopyOutFromFailure(t *testing.T) { 124 b := newBytesIOString("ABCDE") 125 n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ 126 {Start: 1, End: 4}, 127 {Start: 4, End: 7}, 128 }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{}) 129 if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { 130 t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 131 } 132 if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) { 133 t.Errorf("Bytes: got %q, wanted %q", got, want) 134 } 135 } 136 137 func TestBytesIOCopyInToSuccess(t *testing.T) { 138 b := 
newBytesIOString("AfoobarH") 139 var dst bytes.Buffer 140 n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ 141 {Start: 4, End: 7}, 142 {Start: 1, End: 4}, 143 }), safemem.FromIOWriter{&dst}, IOOpts{}) 144 if wantN := int64(6); n != wantN || err != nil { 145 t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN) 146 } 147 if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) { 148 t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) 149 } 150 } 151 152 func TestBytesIOCopyInToFailure(t *testing.T) { 153 b := newBytesIOString("Afoob") 154 var dst bytes.Buffer 155 n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{ 156 {Start: 1, End: 4}, 157 {Start: 4, End: 7}, 158 }), safemem.FromIOWriter{&dst}, IOOpts{}) 159 if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr { 160 t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 161 } 162 if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { 163 t.Errorf("dst.Bytes(): got %q, wanted %q", got, want) 164 } 165 } 166 167 type testStruct struct { 168 Int8 int8 169 Uint8 uint8 170 Int16 int16 171 Uint16 uint16 172 Int32 int32 173 Uint32 uint32 174 Int64 int64 175 Uint64 uint64 176 } 177 178 func TestCopyStringInShort(t *testing.T) { 179 // Tests for string length <= copyStringIncrement. 180 want := strings.Repeat("A", copyStringIncrement-2) 181 mem := want + "\x00" 182 if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { 183 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) 184 } 185 } 186 187 func TestCopyStringInLong(t *testing.T) { 188 // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen 189 // (requiring multiple calls to IO.CopyIn()). 
190 want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4) 191 mem := want + "\x00" 192 if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil { 193 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) 194 } 195 } 196 197 func TestCopyStringInVeryLong(t *testing.T) { 198 // Tests for string length > copyStringMaxInitBufLen (requiring buffer 199 // reallocation). 200 want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4) 201 mem := want + "\x00" 202 if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil { 203 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want) 204 } 205 } 206 207 func TestCopyStringInNoTerminatingZeroByte(t *testing.T) { 208 want := strings.Repeat("A", copyStringIncrement-1) 209 got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{}) 210 if wantErr := syserror.EFAULT; got != want || err != wantErr { 211 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) 212 } 213 } 214 215 func TestCopyStringInTruncatedByMaxlen(t *testing.T) { 216 got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{}) 217 if want, wantErr := strings.Repeat("A", 5), linuxerr.ENAMETOOLONG; got != want || err != wantErr { 218 t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr) 219 } 220 } 221 222 func TestCopyInt32StringsInVec(t *testing.T) { 223 for _, test := range []struct { 224 str string 225 n int 226 initial []int32 227 final []int32 228 }{ 229 { 230 str: "100 200", 231 n: len("100 200"), 232 initial: []int32{1, 2}, 233 final: []int32{100, 200}, 234 }, 235 { 236 // Fewer values ok 237 str: "100", 238 n: len("100"), 239 initial: []int32{1, 2}, 240 final: 
[]int32{100, 2}, 241 }, 242 { 243 // Extra values ok 244 str: "100 200 300", 245 n: len("100 200 "), 246 initial: []int32{1, 2}, 247 final: []int32{100, 200}, 248 }, 249 { 250 // Leading and trailing whitespace ok 251 str: " 100\t200\n", 252 n: len(" 100\t200\n"), 253 initial: []int32{1, 2}, 254 final: []int32{100, 200}, 255 }, 256 } { 257 t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) { 258 src := BytesIOSequence([]byte(test.str)) 259 dsts := append([]int32(nil), test.initial...) 260 if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil { 261 t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n) 262 } 263 if !reflect.DeepEqual(dsts, test.final) { 264 t.Errorf("dsts: got %v, wanted %v", dsts, test.final) 265 } 266 }) 267 } 268 } 269 270 func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) { 271 for _, s := range []string{"", "\n", "a123"} { 272 t.Run(fmt.Sprintf("%q", s), func(t *testing.T) { 273 src := BytesIOSequence([]byte(s)) 274 initial := []int32{1, 2} 275 dsts := append([]int32(nil), initial...) 276 if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); !linuxerr.Equals(linuxerr.EINVAL, err) { 277 t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, linuxerr.EINVAL) 278 } 279 if !reflect.DeepEqual(dsts, initial) { 280 t.Errorf("dsts: got %v, wanted %v", dsts, initial) 281 } 282 }) 283 } 284 } 285 286 func TestIOSequenceCopyOut(t *testing.T) { 287 buf := []byte("ABCD") 288 s := BytesIOSequence(buf) 289 290 // CopyOut limited by len(src). 
291 n, err := s.CopyOut(newContext(), []byte("fo")) 292 if wantN := 2; n != wantN || err != nil { 293 t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 294 } 295 if want := []byte("foCD"); !bytes.Equal(buf, want) { 296 t.Errorf("buf: got %q, wanted %q", buf, want) 297 } 298 s = s.DropFirst(2) 299 if got, want := s.NumBytes(), int64(2); got != want { 300 t.Errorf("NumBytes: got %v, wanted %v", got, want) 301 } 302 303 // CopyOut limited by s.NumBytes(). 304 n, err = s.CopyOut(newContext(), []byte("obar")) 305 if wantN := 2; n != wantN || err != nil { 306 t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 307 } 308 if want := []byte("foob"); !bytes.Equal(buf, want) { 309 t.Errorf("buf: got %q, wanted %q", buf, want) 310 } 311 s = s.DropFirst(2) 312 if got, want := s.NumBytes(), int64(0); got != want { 313 t.Errorf("NumBytes: got %v, wanted %v", got, want) 314 } 315 } 316 317 func TestIOSequenceCopyIn(t *testing.T) { 318 s := BytesIOSequence([]byte("foob")) 319 dst := []byte("ABCDEF") 320 321 // CopyIn limited by len(dst). 322 n, err := s.CopyIn(newContext(), dst[:2]) 323 if wantN := 2; n != wantN || err != nil { 324 t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) 325 } 326 if want := []byte("foCDEF"); !bytes.Equal(dst, want) { 327 t.Errorf("dst: got %q, wanted %q", dst, want) 328 } 329 s = s.DropFirst(2) 330 if got, want := s.NumBytes(), int64(2); got != want { 331 t.Errorf("NumBytes: got %v, wanted %v", got, want) 332 } 333 334 // CopyIn limited by s.Remaining(). 
335 n, err = s.CopyIn(newContext(), dst[2:]) 336 if wantN := 2; n != wantN || err != nil { 337 t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) 338 } 339 if want := []byte("foobEF"); !bytes.Equal(dst, want) { 340 t.Errorf("dst: got %q, wanted %q", dst, want) 341 } 342 s = s.DropFirst(2) 343 if got, want := s.NumBytes(), int64(0); got != want { 344 t.Errorf("NumBytes: got %v, wanted %v", got, want) 345 } 346 } 347 348 func TestIOSequenceZeroOut(t *testing.T) { 349 buf := []byte("ABCD") 350 s := BytesIOSequence(buf) 351 352 // ZeroOut limited by toZero. 353 n, err := s.ZeroOut(newContext(), 2) 354 if wantN := int64(2); n != wantN || err != nil { 355 t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 356 } 357 if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) { 358 t.Errorf("buf: got %q, wanted %q", buf, want) 359 } 360 s = s.DropFirst(2) 361 if got, want := s.NumBytes(), int64(2); got != want { 362 t.Errorf("NumBytes: got %v, wanted %v", got, want) 363 } 364 365 // ZeroOut limited by s.NumBytes(). 366 n, err = s.ZeroOut(newContext(), 4) 367 if wantN := int64(2); n != wantN || err != nil { 368 t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN) 369 } 370 if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) { 371 t.Errorf("buf: got %q, wanted %q", buf, want) 372 } 373 s = s.DropFirst(2) 374 if got, want := s.NumBytes(), int64(0); got != want { 375 t.Errorf("NumBytes: got %v, wanted %v", got, want) 376 } 377 } 378 379 func TestIOSequenceTakeFirst(t *testing.T) { 380 s := BytesIOSequence([]byte("foobar")) 381 if got, want := s.NumBytes(), int64(6); got != want { 382 t.Errorf("NumBytes: got %v, wanted %v", got, want) 383 } 384 385 s = s.TakeFirst(3) 386 if got, want := s.NumBytes(), int64(3); got != want { 387 t.Errorf("NumBytes: got %v, wanted %v", got, want) 388 } 389 390 // TakeFirst(n) where n > s.NumBytes() is a no-op. 
391 s = s.TakeFirst(9) 392 if got, want := s.NumBytes(), int64(3); got != want { 393 t.Errorf("NumBytes: got %v, wanted %v", got, want) 394 } 395 396 var dst [3]byte 397 n, err := s.CopyIn(newContext(), dst[:]) 398 if wantN := 3; n != wantN || err != nil { 399 t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN) 400 } 401 if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) { 402 t.Errorf("dst: got %q, wanted %q", got, want) 403 } 404 s = s.DropFirst(3) 405 if got, want := s.NumBytes(), int64(0); got != want { 406 t.Errorf("NumBytes: got %v, wanted %v", got, want) 407 } 408 }