github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/secio/secio_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package secio 16 17 import ( 18 "bytes" 19 "errors" 20 "io" 21 "io/ioutil" 22 "math" 23 "testing" 24 ) 25 26 var errEndOfBuffer = errors.New("write beyond end of buffer") 27 28 // buffer resembles bytes.Buffer, but implements io.ReaderAt and io.WriterAt. 29 // Reads beyond the end of the buffer return io.EOF. Writes beyond the end of 30 // the buffer return errEndOfBuffer. 31 type buffer struct { 32 Bytes []byte 33 } 34 35 // ReadAt implements io.ReaderAt.ReadAt. 36 func (b *buffer) ReadAt(dst []byte, off int64) (int, error) { 37 if off >= int64(len(b.Bytes)) { 38 return 0, io.EOF 39 } 40 n := copy(dst, b.Bytes[off:]) 41 if n < len(dst) { 42 return n, io.EOF 43 } 44 return n, nil 45 } 46 47 // WriteAt implements io.WriterAt.WriteAt. 48 func (b *buffer) WriteAt(src []byte, off int64) (int, error) { 49 if off >= int64(len(b.Bytes)) { 50 return 0, errEndOfBuffer 51 } 52 n := copy(b.Bytes[off:], src) 53 if n < len(src) { 54 return n, errEndOfBuffer 55 } 56 return n, nil 57 } 58 59 func newBufferString(s string) *buffer { 60 return &buffer{[]byte(s)} 61 } 62 63 func TestOffsetReader(t *testing.T) { 64 buf := newBufferString("foobar") 65 r := NewOffsetReader(buf, 3) 66 dst, err := ioutil.ReadAll(r) 67 if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { 68 t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) 69 } 70 } 71 72 func TestSectionReader(t *testing.T) { 73 buf := newBufferString("foobarbaz") 74 r := NewSectionReader(buf, 3, 3) 75 dst, err := ioutil.ReadAll(r) 76 if want, wantErr := []byte("bar"), ErrReachedLimit; !bytes.Equal(dst, want) || err != wantErr { 77 t.Errorf("ReadAll: got (%q, %v), wanted (%q, %v)", dst, err, want, wantErr) 78 } 79 } 80 81 func TestSectionReaderLimitOverflow(t *testing.T) { 82 // SectionReader behaves like OffsetReader when limit overflows int64. 83 buf := newBufferString("foobar") 84 r := NewSectionReader(buf, 3, math.MaxInt64) 85 dst, err := ioutil.ReadAll(r) 86 if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { 87 t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) 88 } 89 } 90 91 func TestOffsetWriter(t *testing.T) { 92 buf := newBufferString("ABCDEF") 93 w := NewOffsetWriter(buf, 3) 94 n, err := w.Write([]byte("foobar")) 95 if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { 96 t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 97 } 98 if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { 99 t.Errorf("buf.Bytes: got %q, wanted %q", got, want) 100 } 101 } 102 103 func TestSectionWriter(t *testing.T) { 104 buf := newBufferString("ABCDEFGHI") 105 w := NewSectionWriter(buf, 3, 3) 106 n, err := w.Write([]byte("foobar")) 107 if wantN, wantErr := 3, ErrReachedLimit; n != wantN || err != wantErr { 108 t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 109 } 110 if got, want := buf.Bytes, []byte("ABCfooGHI"); !bytes.Equal(got, want) { 111 t.Errorf("buf.Bytes: got %q, wanted %q", got, want) 112 } 113 } 114 115 func TestSectionWriterLimitOverflow(t *testing.T) { 116 // SectionWriter behaves like OffsetWriter when limit overflows int64. 117 buf := newBufferString("ABCDEF") 118 w := NewSectionWriter(buf, 3, math.MaxInt64) 119 n, err := w.Write([]byte("foobar")) 120 if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { 121 t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) 122 } 123 if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { 124 t.Errorf("buf.Bytes: got %q, wanted %q", got, want) 125 } 126 }