gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/usermem/usermem_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package usermem
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"slices"
    21  	"strings"
    22  	"testing"
    23  
    24  	"gvisor.dev/gvisor/pkg/context"
    25  	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    26  	"gvisor.dev/gvisor/pkg/hostarch"
    27  	"gvisor.dev/gvisor/pkg/safemem"
    28  )
    29  
// newContext returns a context.Context that we can use in these tests (we
// can't use contexttest because it depends on usermem).
//
// The returned value is simply context.Background().
func newContext() context.Context {
	return context.Background()
}
    35  
    36  func newBytesIOString(s string) *BytesIO {
    37  	return &BytesIO{[]byte(s)}
    38  }
    39  
    40  func TestBytesIOCopyOutSuccess(t *testing.T) {
    41  	b := newBytesIOString("ABCDE")
    42  	n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{})
    43  	if wantN := 3; n != wantN || err != nil {
    44  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    45  	}
    46  	if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) {
    47  		t.Errorf("Bytes: got %q, wanted %q", got, want)
    48  	}
    49  }
    50  
    51  func TestBytesIOCopyOutFailure(t *testing.T) {
    52  	b := newBytesIOString("ABC")
    53  	n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{})
    54  	if wantN, wantErr := 2, linuxerr.EFAULT; n != wantN || err != wantErr {
    55  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
    56  	}
    57  	if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) {
    58  		t.Errorf("Bytes: got %q, wanted %q", got, want)
    59  	}
    60  }
    61  
    62  func TestBytesIOCopyInSuccess(t *testing.T) {
    63  	b := newBytesIOString("AfooE")
    64  	var dst [3]byte
    65  	n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{})
    66  	if wantN := 3; n != wantN || err != nil {
    67  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    68  	}
    69  	if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) {
    70  		t.Errorf("dst: got %q, wanted %q", got, want)
    71  	}
    72  }
    73  
    74  func TestBytesIOCopyInFailure(t *testing.T) {
    75  	b := newBytesIOString("Afo")
    76  	var dst [3]byte
    77  	n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{})
    78  	if wantN, wantErr := 2, linuxerr.EFAULT; n != wantN || err != wantErr {
    79  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
    80  	}
    81  	if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) {
    82  		t.Errorf("dst: got %q, wanted %q", got, want)
    83  	}
    84  }
    85  
    86  func TestBytesIOZeroOutSuccess(t *testing.T) {
    87  	b := newBytesIOString("ABCD")
    88  	n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{})
    89  	if wantN := int64(2); n != wantN || err != nil {
    90  		t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    91  	}
    92  	if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) {
    93  		t.Errorf("Bytes: got %q, wanted %q", got, want)
    94  	}
    95  }
    96  
    97  func TestBytesIOZeroOutFailure(t *testing.T) {
    98  	b := newBytesIOString("ABC")
    99  	n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{})
   100  	if wantN, wantErr := int64(2), linuxerr.EFAULT; n != wantN || err != wantErr {
   101  		t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
   102  	}
   103  	if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) {
   104  		t.Errorf("Bytes: got %q, wanted %q", got, want)
   105  	}
   106  }
   107  
   108  func TestBytesIOCopyOutFromSuccess(t *testing.T) {
   109  	b := newBytesIOString("ABCDEFGH")
   110  	n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
   111  		{Start: 4, End: 7},
   112  		{Start: 1, End: 4},
   113  	}), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{})
   114  	if wantN := int64(6); n != wantN || err != nil {
   115  		t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   116  	}
   117  	if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) {
   118  		t.Errorf("Bytes: got %q, wanted %q", got, want)
   119  	}
   120  }
   121  
   122  func TestBytesIOCopyOutFromFailure(t *testing.T) {
   123  	b := newBytesIOString("ABCDE")
   124  	n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
   125  		{Start: 1, End: 4},
   126  		{Start: 4, End: 7},
   127  	}), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{})
   128  	if wantN, wantErr := int64(4), linuxerr.EFAULT; n != wantN || err != wantErr {
   129  		t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
   130  	}
   131  	if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) {
   132  		t.Errorf("Bytes: got %q, wanted %q", got, want)
   133  	}
   134  }
   135  
   136  func TestBytesIOCopyInToSuccess(t *testing.T) {
   137  	b := newBytesIOString("AfoobarH")
   138  	var dst bytes.Buffer
   139  	n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
   140  		{Start: 4, End: 7},
   141  		{Start: 1, End: 4},
   142  	}), safemem.FromIOWriter{&dst}, IOOpts{})
   143  	if wantN := int64(6); n != wantN || err != nil {
   144  		t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   145  	}
   146  	if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) {
   147  		t.Errorf("dst.Bytes(): got %q, wanted %q", got, want)
   148  	}
   149  }
   150  
   151  func TestBytesIOCopyInToFailure(t *testing.T) {
   152  	b := newBytesIOString("Afoob")
   153  	var dst bytes.Buffer
   154  	n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
   155  		{Start: 1, End: 4},
   156  		{Start: 4, End: 7},
   157  	}), safemem.FromIOWriter{&dst}, IOOpts{})
   158  	if wantN, wantErr := int64(4), linuxerr.EFAULT; n != wantN || err != wantErr {
   159  		t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
   160  	}
   161  	if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
   162  		t.Errorf("dst.Bytes(): got %q, wanted %q", got, want)
   163  	}
   164  }
   165  
// testStruct has one field of each fixed-width signed and unsigned integer
// type from 8 to 64 bits.
//
// NOTE(review): no test visible in this file references it; confirm whether
// it is still used elsewhere in the package before removing it.
type testStruct struct {
	Int8   int8
	Uint8  uint8
	Int16  int16
	Uint16 uint16
	Int32  int32
	Uint32 uint32
	Int64  int64
	Uint64 uint64
}
   176  
   177  func TestCopyStringInShort(t *testing.T) {
   178  	// Tests for string length <= copyStringIncrement.
   179  	want := strings.Repeat("A", copyStringIncrement-2)
   180  	mem := want + "\x00"
   181  	if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
   182  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
   183  	}
   184  }
   185  
   186  func TestCopyStringInLong(t *testing.T) {
   187  	// Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen
   188  	// (requiring multiple calls to IO.CopyIn()).
   189  	want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4)
   190  	mem := want + "\x00"
   191  	if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
   192  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
   193  	}
   194  }
   195  
   196  func TestCopyStringInVeryLong(t *testing.T) {
   197  	// Tests for string length > copyStringMaxInitBufLen (requiring buffer
   198  	// reallocation).
   199  	want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4)
   200  	mem := want + "\x00"
   201  	if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil {
   202  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
   203  	}
   204  }
   205  
   206  func TestCopyStringInNoTerminatingZeroByte(t *testing.T) {
   207  	want := strings.Repeat("A", copyStringIncrement-1)
   208  	got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{})
   209  	if wantErr := linuxerr.EFAULT; got != want || err != wantErr {
   210  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr)
   211  	}
   212  }
   213  
   214  func TestCopyStringInTruncatedByMaxlen(t *testing.T) {
   215  	got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{})
   216  	if want, wantErr := strings.Repeat("A", 5), linuxerr.ENAMETOOLONG; got != want || err != wantErr {
   217  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr)
   218  	}
   219  }
   220  
   221  func TestCopyInt32StringsInVec(t *testing.T) {
   222  	for _, test := range []struct {
   223  		str     string
   224  		n       int
   225  		initial []int32
   226  		final   []int32
   227  	}{
   228  		{
   229  			str:     "100 200",
   230  			n:       len("100 200"),
   231  			initial: []int32{1, 2},
   232  			final:   []int32{100, 200},
   233  		},
   234  		{
   235  			// Fewer values ok
   236  			str:     "100",
   237  			n:       len("100"),
   238  			initial: []int32{1, 2},
   239  			final:   []int32{100, 2},
   240  		},
   241  		{
   242  			// Extra values ok
   243  			str:     "100 200 300",
   244  			n:       len("100 200 "),
   245  			initial: []int32{1, 2},
   246  			final:   []int32{100, 200},
   247  		},
   248  		{
   249  			// Leading and trailing whitespace ok
   250  			str:     " 100\t200\n",
   251  			n:       len(" 100\t200\n"),
   252  			initial: []int32{1, 2},
   253  			final:   []int32{100, 200},
   254  		},
   255  	} {
   256  		t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) {
   257  			src := BytesIOSequence([]byte(test.str))
   258  			dsts := append([]int32(nil), test.initial...)
   259  			if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil {
   260  				t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n)
   261  			}
   262  			if !slices.Equal(dsts, test.final) {
   263  				t.Errorf("dsts: got %v, wanted %v", dsts, test.final)
   264  			}
   265  		})
   266  	}
   267  }
   268  
   269  func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) {
   270  	for _, s := range []string{"", "\n", "a123"} {
   271  		t.Run(fmt.Sprintf("%q", s), func(t *testing.T) {
   272  			src := BytesIOSequence([]byte(s))
   273  			initial := []int32{1, 2}
   274  			dsts := append([]int32(nil), initial...)
   275  			if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); !linuxerr.Equals(linuxerr.EINVAL, err) {
   276  				t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, linuxerr.EINVAL)
   277  			}
   278  			if !slices.Equal(dsts, initial) {
   279  				t.Errorf("dsts: got %v, wanted %v", dsts, initial)
   280  			}
   281  		})
   282  	}
   283  }
   284  
   285  func TestIOSequenceCopyOut(t *testing.T) {
   286  	buf := []byte("ABCD")
   287  	s := BytesIOSequence(buf)
   288  
   289  	// CopyOut limited by len(src).
   290  	n, err := s.CopyOut(newContext(), []byte("fo"))
   291  	if wantN := 2; n != wantN || err != nil {
   292  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   293  	}
   294  	if want := []byte("foCD"); !bytes.Equal(buf, want) {
   295  		t.Errorf("buf: got %q, wanted %q", buf, want)
   296  	}
   297  	s = s.DropFirst(2)
   298  	if got, want := s.NumBytes(), int64(2); got != want {
   299  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   300  	}
   301  
   302  	// CopyOut limited by s.NumBytes().
   303  	n, err = s.CopyOut(newContext(), []byte("obar"))
   304  	if wantN := 2; n != wantN || err != nil {
   305  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   306  	}
   307  	if want := []byte("foob"); !bytes.Equal(buf, want) {
   308  		t.Errorf("buf: got %q, wanted %q", buf, want)
   309  	}
   310  	s = s.DropFirst(2)
   311  	if got, want := s.NumBytes(), int64(0); got != want {
   312  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   313  	}
   314  }
   315  
   316  func TestIOSequenceCopyIn(t *testing.T) {
   317  	s := BytesIOSequence([]byte("foob"))
   318  	dst := []byte("ABCDEF")
   319  
   320  	// CopyIn limited by len(dst).
   321  	n, err := s.CopyIn(newContext(), dst[:2])
   322  	if wantN := 2; n != wantN || err != nil {
   323  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   324  	}
   325  	if want := []byte("foCDEF"); !bytes.Equal(dst, want) {
   326  		t.Errorf("dst: got %q, wanted %q", dst, want)
   327  	}
   328  	s = s.DropFirst(2)
   329  	if got, want := s.NumBytes(), int64(2); got != want {
   330  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   331  	}
   332  
   333  	// CopyIn limited by s.Remaining().
   334  	n, err = s.CopyIn(newContext(), dst[2:])
   335  	if wantN := 2; n != wantN || err != nil {
   336  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   337  	}
   338  	if want := []byte("foobEF"); !bytes.Equal(dst, want) {
   339  		t.Errorf("dst: got %q, wanted %q", dst, want)
   340  	}
   341  	s = s.DropFirst(2)
   342  	if got, want := s.NumBytes(), int64(0); got != want {
   343  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   344  	}
   345  }
   346  
   347  func TestIOSequenceZeroOut(t *testing.T) {
   348  	buf := []byte("ABCD")
   349  	s := BytesIOSequence(buf)
   350  
   351  	// ZeroOut limited by toZero.
   352  	n, err := s.ZeroOut(newContext(), 2)
   353  	if wantN := int64(2); n != wantN || err != nil {
   354  		t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   355  	}
   356  	if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) {
   357  		t.Errorf("buf: got %q, wanted %q", buf, want)
   358  	}
   359  	s = s.DropFirst(2)
   360  	if got, want := s.NumBytes(), int64(2); got != want {
   361  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   362  	}
   363  
   364  	// ZeroOut limited by s.NumBytes().
   365  	n, err = s.ZeroOut(newContext(), 4)
   366  	if wantN := int64(2); n != wantN || err != nil {
   367  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   368  	}
   369  	if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) {
   370  		t.Errorf("buf: got %q, wanted %q", buf, want)
   371  	}
   372  	s = s.DropFirst(2)
   373  	if got, want := s.NumBytes(), int64(0); got != want {
   374  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   375  	}
   376  }
   377  
   378  func TestIOSequenceTakeFirst(t *testing.T) {
   379  	s := BytesIOSequence([]byte("foobar"))
   380  	if got, want := s.NumBytes(), int64(6); got != want {
   381  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   382  	}
   383  
   384  	s = s.TakeFirst(3)
   385  	if got, want := s.NumBytes(), int64(3); got != want {
   386  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   387  	}
   388  
   389  	// TakeFirst(n) where n > s.NumBytes() is a no-op.
   390  	s = s.TakeFirst(9)
   391  	if got, want := s.NumBytes(), int64(3); got != want {
   392  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   393  	}
   394  
   395  	var dst [3]byte
   396  	n, err := s.CopyIn(newContext(), dst[:])
   397  	if wantN := 3; n != wantN || err != nil {
   398  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   399  	}
   400  	if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) {
   401  		t.Errorf("dst: got %q, wanted %q", got, want)
   402  	}
   403  	s = s.DropFirst(3)
   404  	if got, want := s.NumBytes(), int64(0); got != want {
   405  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   406  	}
   407  }