gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/secio/secio_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package secio
    16  
    17  import (
    18  	"bytes"
    19  	"errors"
    20  	"io"
    21  	"io/ioutil"
    22  	"math"
    23  	"testing"
    24  )
    25  
// errEndOfBuffer is the sentinel error returned by buffer.WriteAt when a
// write starts at or past the end of the buffer, or when the source does not
// fit entirely within the remaining bytes.
var errEndOfBuffer = errors.New("write beyond end of buffer")
    27  
// buffer resembles bytes.Buffer, but implements io.ReaderAt and io.WriterAt
// over a byte slice that its methods never grow.
//
// Reads that start at or beyond the end of the buffer, or that cannot fill
// the destination, return io.EOF. Writes that start at or beyond the end of
// the buffer, or that do not fit entirely, return errEndOfBuffer.
type buffer struct {
	// Bytes is the backing storage: ReadAt copies out of it and WriteAt
	// overwrites it in place.
	Bytes []byte
}
    34  
    35  // ReadAt implements io.ReaderAt.ReadAt.
    36  func (b *buffer) ReadAt(dst []byte, off int64) (int, error) {
    37  	if off >= int64(len(b.Bytes)) {
    38  		return 0, io.EOF
    39  	}
    40  	n := copy(dst, b.Bytes[off:])
    41  	if n < len(dst) {
    42  		return n, io.EOF
    43  	}
    44  	return n, nil
    45  }
    46  
    47  // WriteAt implements io.WriterAt.WriteAt.
    48  func (b *buffer) WriteAt(src []byte, off int64) (int, error) {
    49  	if off >= int64(len(b.Bytes)) {
    50  		return 0, errEndOfBuffer
    51  	}
    52  	n := copy(b.Bytes[off:], src)
    53  	if n < len(src) {
    54  		return n, errEndOfBuffer
    55  	}
    56  	return n, nil
    57  }
    58  
    59  func newBufferString(s string) *buffer {
    60  	return &buffer{[]byte(s)}
    61  }
    62  
    63  func TestOffsetReader(t *testing.T) {
    64  	buf := newBufferString("foobar")
    65  	r := NewOffsetReader(buf, 3)
    66  	dst, err := ioutil.ReadAll(r)
    67  	if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil {
    68  		t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want)
    69  	}
    70  }
    71  
    72  func TestSectionReader(t *testing.T) {
    73  	buf := newBufferString("foobarbaz")
    74  	r := NewSectionReader(buf, 3, 3)
    75  	dst, err := ioutil.ReadAll(r)
    76  	if want, wantErr := []byte("bar"), ErrReachedLimit; !bytes.Equal(dst, want) || err != wantErr {
    77  		t.Errorf("ReadAll: got (%q, %v), wanted (%q, %v)", dst, err, want, wantErr)
    78  	}
    79  }
    80  
    81  func TestSectionReaderLimitOverflow(t *testing.T) {
    82  	// SectionReader behaves like OffsetReader when limit overflows int64.
    83  	buf := newBufferString("foobar")
    84  	r := NewSectionReader(buf, 3, math.MaxInt64)
    85  	dst, err := ioutil.ReadAll(r)
    86  	if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil {
    87  		t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want)
    88  	}
    89  }
    90  
    91  func TestOffsetWriter(t *testing.T) {
    92  	buf := newBufferString("ABCDEF")
    93  	w := NewOffsetWriter(buf, 3)
    94  	n, err := w.Write([]byte("foobar"))
    95  	if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr {
    96  		t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
    97  	}
    98  	if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) {
    99  		t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
   100  	}
   101  }
   102  
   103  func TestSectionWriter(t *testing.T) {
   104  	buf := newBufferString("ABCDEFGHI")
   105  	w := NewSectionWriter(buf, 3, 3)
   106  	n, err := w.Write([]byte("foobar"))
   107  	if wantN, wantErr := 3, ErrReachedLimit; n != wantN || err != wantErr {
   108  		t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
   109  	}
   110  	if got, want := buf.Bytes, []byte("ABCfooGHI"); !bytes.Equal(got, want) {
   111  		t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
   112  	}
   113  }
   114  
   115  func TestSectionWriterLimitOverflow(t *testing.T) {
   116  	// SectionWriter behaves like OffsetWriter when limit overflows int64.
   117  	buf := newBufferString("ABCDEF")
   118  	w := NewSectionWriter(buf, 3, math.MaxInt64)
   119  	n, err := w.Write([]byte("foobar"))
   120  	if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr {
   121  		t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
   122  	}
   123  	if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) {
   124  		t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
   125  	}
   126  }