github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/safemem/io_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safemem
    16  
    17  import (
    18  	"bytes"
    19  	"io"
    20  	"testing"
    21  )
    22  
    23  func makeBlocks(slices ...[]byte) []Block {
    24  	blocks := make([]Block, 0, len(slices))
    25  	for _, s := range slices {
    26  		blocks = append(blocks, BlockFromSafeSlice(s))
    27  	}
    28  	return blocks
    29  }
    30  
    31  func TestFromIOReaderFullRead(t *testing.T) {
    32  	r := FromIOReader{bytes.NewBufferString("foobar")}
    33  	dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
    34  	n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
    35  	if wantN := uint64(6); n != wantN || err != nil {
    36  		t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    37  	}
    38  	for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
    39  		if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
    40  			t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
    41  		}
    42  	}
    43  }
    44  
    45  type eofHidingReader struct {
    46  	Reader io.Reader
    47  }
    48  
    49  func (r eofHidingReader) Read(dst []byte) (int, error) {
    50  	n, err := r.Reader.Read(dst)
    51  	if err == io.EOF {
    52  		return n, nil
    53  	}
    54  	return n, err
    55  }
    56  
    57  func TestFromIOReaderPartialRead(t *testing.T) {
    58  	r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}}
    59  	dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
    60  	n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
    61  	// FromIOReader should stop after the eofHidingReader returns (1, nil)
    62  	// for a 3-byte read.
    63  	if wantN := uint64(4); n != wantN || err != nil {
    64  		t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    65  	}
    66  	for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} {
    67  		if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
    68  			t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
    69  		}
    70  	}
    71  }
    72  
    73  type singleByteReader struct {
    74  	Reader io.Reader
    75  }
    76  
    77  func (r singleByteReader) Read(dst []byte) (int, error) {
    78  	if len(dst) == 0 {
    79  		return r.Reader.Read(dst)
    80  	}
    81  	return r.Reader.Read(dst[:1])
    82  }
    83  
    84  func TestSingleByteReader(t *testing.T) {
    85  	r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
    86  	dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
    87  	n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
    88  	// FromIOReader should stop after the singleByteReader returns (1, nil)
    89  	// for a 3-byte read.
    90  	if wantN := uint64(1); n != wantN || err != nil {
    91  		t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
    92  	}
    93  	for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} {
    94  		if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
    95  			t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
    96  		}
    97  	}
    98  }
    99  
   100  func TestReadFullToBlocks(t *testing.T) {
   101  	r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
   102  	dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
   103  	n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts))
   104  	// ReadFullToBlocks should call into FromIOReader => singleByteReader
   105  	// repeatedly until dsts is exhausted.
   106  	if wantN := uint64(6); n != wantN || err != nil {
   107  		t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   108  	}
   109  	for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
   110  		if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
   111  			t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
   112  		}
   113  	}
   114  }
   115  
   116  func TestFromIOWriterFullWrite(t *testing.T) {
   117  	srcs := makeBlocks([]byte("foo"), []byte("bar"))
   118  	var dst bytes.Buffer
   119  	w := FromIOWriter{&dst}
   120  	n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
   121  	if wantN := uint64(6); n != wantN || err != nil {
   122  		t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   123  	}
   124  	if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
   125  		t.Errorf("dst: got %q, wanted %q", got, want)
   126  	}
   127  }
   128  
   129  type limitedWriter struct {
   130  	Writer io.Writer
   131  	Done   int
   132  	Limit  int
   133  }
   134  
   135  func (w *limitedWriter) Write(src []byte) (int, error) {
   136  	count := len(src)
   137  	if count > (w.Limit - w.Done) {
   138  		count = w.Limit - w.Done
   139  	}
   140  	n, err := w.Writer.Write(src[:count])
   141  	w.Done += n
   142  	return n, err
   143  }
   144  
   145  func TestFromIOWriterPartialWrite(t *testing.T) {
   146  	srcs := makeBlocks([]byte("foo"), []byte("bar"))
   147  	var dst bytes.Buffer
   148  	w := FromIOWriter{&limitedWriter{&dst, 0, 4}}
   149  	n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
   150  	// FromIOWriter should stop after the limitedWriter returns (1, nil) for a
   151  	// 3-byte write.
   152  	if wantN := uint64(4); n != wantN || err != nil {
   153  		t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   154  	}
   155  	if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
   156  		t.Errorf("dst: got %q, wanted %q", got, want)
   157  	}
   158  }
   159  
   160  type singleByteWriter struct {
   161  	Writer io.Writer
   162  }
   163  
   164  func (w singleByteWriter) Write(src []byte) (int, error) {
   165  	if len(src) == 0 {
   166  		return w.Writer.Write(src)
   167  	}
   168  	return w.Writer.Write(src[:1])
   169  }
   170  
   171  func TestSingleByteWriter(t *testing.T) {
   172  	srcs := makeBlocks([]byte("foo"), []byte("bar"))
   173  	var dst bytes.Buffer
   174  	w := FromIOWriter{singleByteWriter{&dst}}
   175  	n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
   176  	// FromIOWriter should stop after the singleByteWriter returns (1, nil)
   177  	// for a 3-byte write.
   178  	if wantN := uint64(1); n != wantN || err != nil {
   179  		t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   180  	}
   181  	if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) {
   182  		t.Errorf("dst: got %q, wanted %q", got, want)
   183  	}
   184  }
   185  
   186  func TestWriteFullToBlocks(t *testing.T) {
   187  	srcs := makeBlocks([]byte("foo"), []byte("bar"))
   188  	var dst bytes.Buffer
   189  	w := FromIOWriter{singleByteWriter{&dst}}
   190  	n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs))
   191  	// WriteFullToBlocks should call into FromIOWriter => singleByteWriter
   192  	// repeatedly until srcs is exhausted.
   193  	if wantN := uint64(6); n != wantN || err != nil {
   194  		t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
   195  	}
   196  	if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
   197  		t.Errorf("dst: got %q, wanted %q", got, want)
   198  	}
   199  }