github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/safemem/io_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package safemem 16 17 import ( 18 "bytes" 19 "io" 20 "testing" 21 ) 22 23 func makeBlocks(slices ...[]byte) []Block { 24 blocks := make([]Block, 0, len(slices)) 25 for _, s := range slices { 26 blocks = append(blocks, BlockFromSafeSlice(s)) 27 } 28 return blocks 29 } 30 31 func TestFromIOReaderFullRead(t *testing.T) { 32 r := FromIOReader{bytes.NewBufferString("foobar")} 33 dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) 34 n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) 35 if wantN := uint64(6); n != wantN || err != nil { 36 t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) 37 } 38 for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { 39 if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { 40 t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) 41 } 42 } 43 } 44 45 type eofHidingReader struct { 46 Reader io.Reader 47 } 48 49 func (r eofHidingReader) Read(dst []byte) (int, error) { 50 n, err := r.Reader.Read(dst) 51 if err == io.EOF { 52 return n, nil 53 } 54 return n, err 55 } 56 57 func TestFromIOReaderPartialRead(t *testing.T) { 58 r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}} 59 dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) 60 n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) 61 // FromIOReader should stop after the eofHidingReader returns (1, nil) 62 // for a 3-byte read. 63 if wantN := uint64(4); n != wantN || err != nil { 64 t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) 65 } 66 for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} { 67 if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { 68 t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) 69 } 70 } 71 } 72 73 type singleByteReader struct { 74 Reader io.Reader 75 } 76 77 func (r singleByteReader) Read(dst []byte) (int, error) { 78 if len(dst) == 0 { 79 return r.Reader.Read(dst) 80 } 81 return r.Reader.Read(dst[:1]) 82 } 83 84 func TestSingleByteReader(t *testing.T) { 85 r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} 86 dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) 87 n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) 88 // FromIOReader should stop after the singleByteReader returns (1, nil) 89 // for a 3-byte read. 90 if wantN := uint64(1); n != wantN || err != nil { 91 t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) 92 } 93 for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} { 94 if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { 95 t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) 96 } 97 } 98 } 99 100 func TestReadFullToBlocks(t *testing.T) { 101 r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} 102 dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) 103 n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts)) 104 // ReadFullToBlocks should call into FromIOReader => singleByteReader 105 // repeatedly until dsts is exhausted. 106 if wantN := uint64(6); n != wantN || err != nil { 107 t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) 108 } 109 for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { 110 if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { 111 t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) 112 } 113 } 114 } 115 116 func TestFromIOWriterFullWrite(t *testing.T) { 117 srcs := makeBlocks([]byte("foo"), []byte("bar")) 118 var dst bytes.Buffer 119 w := FromIOWriter{&dst} 120 n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) 121 if wantN := uint64(6); n != wantN || err != nil { 122 t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) 123 } 124 if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { 125 t.Errorf("dst: got %q, wanted %q", got, want) 126 } 127 } 128 129 type limitedWriter struct { 130 Writer io.Writer 131 Done int 132 Limit int 133 } 134 135 func (w *limitedWriter) Write(src []byte) (int, error) { 136 count := len(src) 137 if count > (w.Limit - w.Done) { 138 count = w.Limit - w.Done 139 } 140 n, err := w.Writer.Write(src[:count]) 141 w.Done += n 142 return n, err 143 } 144 145 func TestFromIOWriterPartialWrite(t *testing.T) { 146 srcs := makeBlocks([]byte("foo"), []byte("bar")) 147 var dst bytes.Buffer 148 w := FromIOWriter{&limitedWriter{&dst, 0, 4}} 149 n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) 150 // FromIOWriter should stop after the limitedWriter returns (1, nil) for a 151 // 3-byte write. 152 if wantN := uint64(4); n != wantN || err != nil { 153 t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) 154 } 155 if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { 156 t.Errorf("dst: got %q, wanted %q", got, want) 157 } 158 } 159 160 type singleByteWriter struct { 161 Writer io.Writer 162 } 163 164 func (w singleByteWriter) Write(src []byte) (int, error) { 165 if len(src) == 0 { 166 return w.Writer.Write(src) 167 } 168 return w.Writer.Write(src[:1]) 169 } 170 171 func TestSingleByteWriter(t *testing.T) { 172 srcs := makeBlocks([]byte("foo"), []byte("bar")) 173 var dst bytes.Buffer 174 w := FromIOWriter{singleByteWriter{&dst}} 175 n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) 176 // FromIOWriter should stop after the singleByteWriter returns (1, nil) 177 // for a 3-byte write. 178 if wantN := uint64(1); n != wantN || err != nil { 179 t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) 180 } 181 if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) { 182 t.Errorf("dst: got %q, wanted %q", got, want) 183 } 184 } 185 186 func TestWriteFullToBlocks(t *testing.T) { 187 srcs := makeBlocks([]byte("foo"), []byte("bar")) 188 var dst bytes.Buffer 189 w := FromIOWriter{singleByteWriter{&dst}} 190 n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs)) 191 // WriteFullToBlocks should call into FromIOWriter => singleByteWriter 192 // repeatedly until srcs is exhausted. 193 if wantN := uint64(6); n != wantN || err != nil { 194 t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) 195 } 196 if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { 197 t.Errorf("dst: got %q, wanted %q", got, want) 198 } 199 }