github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/usermem/usermem_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package usermem
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"reflect"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/SagerNet/gvisor/pkg/context"
    25  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    26  	"github.com/SagerNet/gvisor/pkg/hostarch"
    27  	"github.com/SagerNet/gvisor/pkg/safemem"
    28  	"github.com/SagerNet/gvisor/pkg/syserror"
    29  )
    30  
    31  // newContext returns a context.Context that we can use in these tests (we
    32  // can't use contexttest because it depends on usermem).
    33  func newContext() context.Context {
    34  	return context.Background()
    35  }
    36  
    37  func newBytesIOString(s string) *BytesIO {
    38  	return &BytesIO{[]byte(s)}
    39  }
    40  
    41  func TestBytesIOCopyOutSuccess(t *testing.T) {
    42  	b := newBytesIOString("ABCDE")
    43  	n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{})
    44  	if wantN := 3; n != wantN || err != nil {
    45  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    46  	}
    47  	if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) {
    48  		t.Errorf("Bytes: got %q, wanted %q", got, want)
    49  	}
    50  }
    51  
    52  func TestBytesIOCopyOutFailure(t *testing.T) {
    53  	b := newBytesIOString("ABC")
    54  	n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{})
    55  	if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr {
    56  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
    57  	}
    58  	if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) {
    59  		t.Errorf("Bytes: got %q, wanted %q", got, want)
    60  	}
    61  }
    62  
    63  func TestBytesIOCopyInSuccess(t *testing.T) {
    64  	b := newBytesIOString("AfooE")
    65  	var dst [3]byte
    66  	n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{})
    67  	if wantN := 3; n != wantN || err != nil {
    68  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    69  	}
    70  	if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) {
    71  		t.Errorf("dst: got %q, wanted %q", got, want)
    72  	}
    73  }
    74  
    75  func TestBytesIOCopyInFailure(t *testing.T) {
    76  	b := newBytesIOString("Afo")
    77  	var dst [3]byte
    78  	n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{})
    79  	if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr {
    80  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
    81  	}
    82  	if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) {
    83  		t.Errorf("dst: got %q, wanted %q", got, want)
    84  	}
    85  }
    86  
    87  func TestBytesIOZeroOutSuccess(t *testing.T) {
    88  	b := newBytesIOString("ABCD")
    89  	n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{})
    90  	if wantN := int64(2); n != wantN || err != nil {
    91  		t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    92  	}
    93  	if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) {
    94  		t.Errorf("Bytes: got %q, wanted %q", got, want)
    95  	}
    96  }
    97  
    98  func TestBytesIOZeroOutFailure(t *testing.T) {
    99  	b := newBytesIOString("ABC")
   100  	n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{})
   101  	if wantN, wantErr := int64(2), syserror.EFAULT; n != wantN || err != wantErr {
   102  		t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
   103  	}
   104  	if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) {
   105  		t.Errorf("Bytes: got %q, wanted %q", got, want)
   106  	}
   107  }
   108  
   109  func TestBytesIOCopyOutFromSuccess(t *testing.T) {
   110  	b := newBytesIOString("ABCDEFGH")
   111  	n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
   112  		{Start: 4, End: 7},
   113  		{Start: 1, End: 4},
   114  	}), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{})
   115  	if wantN := int64(6); n != wantN || err != nil {
   116  		t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   117  	}
   118  	if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) {
   119  		t.Errorf("Bytes: got %q, wanted %q", got, want)
   120  	}
   121  }
   122  
   123  func TestBytesIOCopyOutFromFailure(t *testing.T) {
   124  	b := newBytesIOString("ABCDE")
   125  	n, err := b.CopyOutFrom(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
   126  		{Start: 1, End: 4},
   127  		{Start: 4, End: 7},
   128  	}), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{})
   129  	if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr {
   130  		t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
   131  	}
   132  	if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) {
   133  		t.Errorf("Bytes: got %q, wanted %q", got, want)
   134  	}
   135  }
   136  
   137  func TestBytesIOCopyInToSuccess(t *testing.T) {
   138  	b := newBytesIOString("AfoobarH")
   139  	var dst bytes.Buffer
   140  	n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
   141  		{Start: 4, End: 7},
   142  		{Start: 1, End: 4},
   143  	}), safemem.FromIOWriter{&dst}, IOOpts{})
   144  	if wantN := int64(6); n != wantN || err != nil {
   145  		t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   146  	}
   147  	if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) {
   148  		t.Errorf("dst.Bytes(): got %q, wanted %q", got, want)
   149  	}
   150  }
   151  
   152  func TestBytesIOCopyInToFailure(t *testing.T) {
   153  	b := newBytesIOString("Afoob")
   154  	var dst bytes.Buffer
   155  	n, err := b.CopyInTo(newContext(), hostarch.AddrRangeSeqFromSlice([]hostarch.AddrRange{
   156  		{Start: 1, End: 4},
   157  		{Start: 4, End: 7},
   158  	}), safemem.FromIOWriter{&dst}, IOOpts{})
   159  	if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr {
   160  		t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
   161  	}
   162  	if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
   163  		t.Errorf("dst.Bytes(): got %q, wanted %q", got, want)
   164  	}
   165  }
   166  
   167  type testStruct struct {
   168  	Int8   int8
   169  	Uint8  uint8
   170  	Int16  int16
   171  	Uint16 uint16
   172  	Int32  int32
   173  	Uint32 uint32
   174  	Int64  int64
   175  	Uint64 uint64
   176  }
   177  
   178  func TestCopyStringInShort(t *testing.T) {
   179  	// Tests for string length <= copyStringIncrement.
   180  	want := strings.Repeat("A", copyStringIncrement-2)
   181  	mem := want + "\x00"
   182  	if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
   183  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
   184  	}
   185  }
   186  
   187  func TestCopyStringInLong(t *testing.T) {
   188  	// Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen
   189  	// (requiring multiple calls to IO.CopyIn()).
   190  	want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4)
   191  	mem := want + "\x00"
   192  	if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
   193  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
   194  	}
   195  }
   196  
   197  func TestCopyStringInVeryLong(t *testing.T) {
   198  	// Tests for string length > copyStringMaxInitBufLen (requiring buffer
   199  	// reallocation).
   200  	want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4)
   201  	mem := want + "\x00"
   202  	if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil {
   203  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
   204  	}
   205  }
   206  
   207  func TestCopyStringInNoTerminatingZeroByte(t *testing.T) {
   208  	want := strings.Repeat("A", copyStringIncrement-1)
   209  	got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{})
   210  	if wantErr := syserror.EFAULT; got != want || err != wantErr {
   211  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr)
   212  	}
   213  }
   214  
   215  func TestCopyStringInTruncatedByMaxlen(t *testing.T) {
   216  	got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{})
   217  	if want, wantErr := strings.Repeat("A", 5), linuxerr.ENAMETOOLONG; got != want || err != wantErr {
   218  		t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr)
   219  	}
   220  }
   221  
   222  func TestCopyInt32StringsInVec(t *testing.T) {
   223  	for _, test := range []struct {
   224  		str     string
   225  		n       int
   226  		initial []int32
   227  		final   []int32
   228  	}{
   229  		{
   230  			str:     "100 200",
   231  			n:       len("100 200"),
   232  			initial: []int32{1, 2},
   233  			final:   []int32{100, 200},
   234  		},
   235  		{
   236  			// Fewer values ok
   237  			str:     "100",
   238  			n:       len("100"),
   239  			initial: []int32{1, 2},
   240  			final:   []int32{100, 2},
   241  		},
   242  		{
   243  			// Extra values ok
   244  			str:     "100 200 300",
   245  			n:       len("100 200 "),
   246  			initial: []int32{1, 2},
   247  			final:   []int32{100, 200},
   248  		},
   249  		{
   250  			// Leading and trailing whitespace ok
   251  			str:     " 100\t200\n",
   252  			n:       len(" 100\t200\n"),
   253  			initial: []int32{1, 2},
   254  			final:   []int32{100, 200},
   255  		},
   256  	} {
   257  		t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) {
   258  			src := BytesIOSequence([]byte(test.str))
   259  			dsts := append([]int32(nil), test.initial...)
   260  			if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil {
   261  				t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n)
   262  			}
   263  			if !reflect.DeepEqual(dsts, test.final) {
   264  				t.Errorf("dsts: got %v, wanted %v", dsts, test.final)
   265  			}
   266  		})
   267  	}
   268  }
   269  
   270  func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) {
   271  	for _, s := range []string{"", "\n", "a123"} {
   272  		t.Run(fmt.Sprintf("%q", s), func(t *testing.T) {
   273  			src := BytesIOSequence([]byte(s))
   274  			initial := []int32{1, 2}
   275  			dsts := append([]int32(nil), initial...)
   276  			if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); !linuxerr.Equals(linuxerr.EINVAL, err) {
   277  				t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, linuxerr.EINVAL)
   278  			}
   279  			if !reflect.DeepEqual(dsts, initial) {
   280  				t.Errorf("dsts: got %v, wanted %v", dsts, initial)
   281  			}
   282  		})
   283  	}
   284  }
   285  
   286  func TestIOSequenceCopyOut(t *testing.T) {
   287  	buf := []byte("ABCD")
   288  	s := BytesIOSequence(buf)
   289  
   290  	// CopyOut limited by len(src).
   291  	n, err := s.CopyOut(newContext(), []byte("fo"))
   292  	if wantN := 2; n != wantN || err != nil {
   293  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   294  	}
   295  	if want := []byte("foCD"); !bytes.Equal(buf, want) {
   296  		t.Errorf("buf: got %q, wanted %q", buf, want)
   297  	}
   298  	s = s.DropFirst(2)
   299  	if got, want := s.NumBytes(), int64(2); got != want {
   300  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   301  	}
   302  
   303  	// CopyOut limited by s.NumBytes().
   304  	n, err = s.CopyOut(newContext(), []byte("obar"))
   305  	if wantN := 2; n != wantN || err != nil {
   306  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   307  	}
   308  	if want := []byte("foob"); !bytes.Equal(buf, want) {
   309  		t.Errorf("buf: got %q, wanted %q", buf, want)
   310  	}
   311  	s = s.DropFirst(2)
   312  	if got, want := s.NumBytes(), int64(0); got != want {
   313  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   314  	}
   315  }
   316  
   317  func TestIOSequenceCopyIn(t *testing.T) {
   318  	s := BytesIOSequence([]byte("foob"))
   319  	dst := []byte("ABCDEF")
   320  
   321  	// CopyIn limited by len(dst).
   322  	n, err := s.CopyIn(newContext(), dst[:2])
   323  	if wantN := 2; n != wantN || err != nil {
   324  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   325  	}
   326  	if want := []byte("foCDEF"); !bytes.Equal(dst, want) {
   327  		t.Errorf("dst: got %q, wanted %q", dst, want)
   328  	}
   329  	s = s.DropFirst(2)
   330  	if got, want := s.NumBytes(), int64(2); got != want {
   331  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   332  	}
   333  
   334  	// CopyIn limited by s.Remaining().
   335  	n, err = s.CopyIn(newContext(), dst[2:])
   336  	if wantN := 2; n != wantN || err != nil {
   337  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   338  	}
   339  	if want := []byte("foobEF"); !bytes.Equal(dst, want) {
   340  		t.Errorf("dst: got %q, wanted %q", dst, want)
   341  	}
   342  	s = s.DropFirst(2)
   343  	if got, want := s.NumBytes(), int64(0); got != want {
   344  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   345  	}
   346  }
   347  
   348  func TestIOSequenceZeroOut(t *testing.T) {
   349  	buf := []byte("ABCD")
   350  	s := BytesIOSequence(buf)
   351  
   352  	// ZeroOut limited by toZero.
   353  	n, err := s.ZeroOut(newContext(), 2)
   354  	if wantN := int64(2); n != wantN || err != nil {
   355  		t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   356  	}
   357  	if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) {
   358  		t.Errorf("buf: got %q, wanted %q", buf, want)
   359  	}
   360  	s = s.DropFirst(2)
   361  	if got, want := s.NumBytes(), int64(2); got != want {
   362  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   363  	}
   364  
   365  	// ZeroOut limited by s.NumBytes().
   366  	n, err = s.ZeroOut(newContext(), 4)
   367  	if wantN := int64(2); n != wantN || err != nil {
   368  		t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   369  	}
   370  	if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) {
   371  		t.Errorf("buf: got %q, wanted %q", buf, want)
   372  	}
   373  	s = s.DropFirst(2)
   374  	if got, want := s.NumBytes(), int64(0); got != want {
   375  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   376  	}
   377  }
   378  
   379  func TestIOSequenceTakeFirst(t *testing.T) {
   380  	s := BytesIOSequence([]byte("foobar"))
   381  	if got, want := s.NumBytes(), int64(6); got != want {
   382  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   383  	}
   384  
   385  	s = s.TakeFirst(3)
   386  	if got, want := s.NumBytes(), int64(3); got != want {
   387  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   388  	}
   389  
   390  	// TakeFirst(n) where n > s.NumBytes() is a no-op.
   391  	s = s.TakeFirst(9)
   392  	if got, want := s.NumBytes(), int64(3); got != want {
   393  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   394  	}
   395  
   396  	var dst [3]byte
   397  	n, err := s.CopyIn(newContext(), dst[:])
   398  	if wantN := 3; n != wantN || err != nil {
   399  		t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   400  	}
   401  	if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) {
   402  		t.Errorf("dst: got %q, wanted %q", got, want)
   403  	}
   404  	s = s.DropFirst(3)
   405  	if got, want := s.NumBytes(), int64(0); got != want {
   406  		t.Errorf("NumBytes: got %v, wanted %v", got, want)
   407  	}
   408  }