github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/utils/iohelp/read_test.go (about) 1 // Copyright 2019 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package iohelp 16 17 import ( 18 "bufio" 19 "bytes" 20 "errors" 21 "io" 22 "reflect" 23 "testing" 24 "time" 25 26 "github.com/stretchr/testify/assert" 27 28 "github.com/dolthub/dolt/go/libraries/utils/mathutil" 29 "github.com/dolthub/dolt/go/libraries/utils/osutil" 30 "github.com/dolthub/dolt/go/libraries/utils/test" 31 ) 32 33 func TestErrPreservingReader(t *testing.T) { 34 tr := test.NewTestReader(32, 16) 35 epr := NewErrPreservingReader(tr) 36 37 read1, noErr1 := ReadNBytes(epr, 8) 38 read2, noErr2 := ReadNBytes(epr, 8) 39 read3, firstErr := ReadNBytes(epr, 8) 40 read4, secondErr := ReadNBytes(epr, 8) 41 42 for i := 0; i < 8; i++ { 43 if read1[i] != byte(i) || read2[i] != byte(i)+8 { 44 t.Error("Unexpected values read.") 45 } 46 } 47 48 if read3 != nil || read4 != nil { 49 t.Error("Unexpected read values should be nil.") 50 } 51 52 if noErr1 != nil || noErr2 != nil { 53 t.Error("Unexpected error.") 54 } 55 56 if firstErr == nil || secondErr == nil || epr.Err == nil { 57 t.Error("Expected error not received.") 58 } else { 59 first := firstErr.(*test.TestError).ErrId 60 second := secondErr.(*test.TestError).ErrId 61 preservedErrID := epr.Err.(*test.TestError).ErrId 62 63 if preservedErrID != first || preservedErrID != second { 64 t.Error("Error not preserved properly.") 65 } 66 } 67 } 68 69 var rlTests = []struct { 70 inputStr string 71 expectedLines []string 72 }{ 73 {"line 1\nline 2\r\nline 3\n", []string{"line 1", "line 2", "line 3", ""}}, 74 {"line 1\nline 2\r\nline 3", []string{"line 1", "line 2", "line 3"}}, 75 {"\r\nline 1\nline 2\r\nline 3\r\r\r\n\n", []string{"", "line 1", "line 2", "line 3", "", ""}}, 76 } 77 78 func TestReadReadLineFunctions(t *testing.T) { 79 for _, test := range rlTests { 80 bufferedTest := getTestReadLineClosure(test.inputStr) 81 unbufferedTest := getTestReadLineNoBufClosure(test.inputStr) 82 83 testReadLineFunctions(t, "buffered", test.expectedLines, bufferedTest) 84 testReadLineFunctions(t, "unbuffered", test.expectedLines, unbufferedTest) 85 } 86 } 87 88 func getTestReadLineClosure(inputStr string) func() (string, bool, error) { 89 r := bytes.NewReader([]byte(inputStr)) 90 br := bufio.NewReader(r) 91 92 return func() (string, bool, error) { 93 return ReadLine(br) 94 } 95 } 96 97 func getTestReadLineNoBufClosure(inputStr string) func() (string, bool, error) { 98 r := bytes.NewReader([]byte(inputStr)) 99 100 return func() (string, bool, error) { 101 return ReadLineNoBuf(r) 102 } 103 } 104 105 func testReadLineFunctions(t *testing.T, testType string, expected []string, rlFunc func() (string, bool, error)) { 106 var isDone bool 107 var line string 108 var err error 109 110 lines := make([]string, 0, len(expected)) 111 for !isDone { 112 line, isDone, err = rlFunc() 113 114 if err == nil { 115 lines = append(lines, line) 116 } 117 } 118 119 if !reflect.DeepEqual(lines, expected) { 120 t.Error("Received unexpected results.") 121 } 122 } 123 124 var ErrClosed = errors.New("") 125 126 type FixedRateDataGenerator struct { 127 BytesPerInterval int 128 Interval time.Duration 129 lastRead time.Time 130 closeChan chan struct{} 131 dataGenerated uint64 132 } 133 134 func NewFixedRateDataGenerator(bytesPerInterval int, interval time.Duration) *FixedRateDataGenerator { 135 return &FixedRateDataGenerator{ 136 bytesPerInterval, 137 interval, 138 time.Now(), 139 make(chan struct{}), 140 0, 141 } 142 } 143 144 func (gen *FixedRateDataGenerator) Read(p []byte) (int, error) { 145 nextRead := gen.Interval - (time.Now().Sub(gen.lastRead)) 146 147 select { 148 case <-gen.closeChan: 149 return 0, ErrClosed 150 case <-time.After(nextRead): 151 gen.dataGenerated += uint64(gen.BytesPerInterval) 152 gen.lastRead = time.Now() 153 return mathutil.Min(gen.BytesPerInterval, len(p)), nil 154 } 155 } 156 157 func (gen *FixedRateDataGenerator) Close() error { 158 close(gen.closeChan) 159 return nil 160 } 161 162 type ErroringReader struct { 163 Err error 164 } 165 166 func (er ErroringReader) Read(p []byte) (int, error) { 167 return 0, er.Err 168 } 169 170 func (er ErroringReader) Close() error { 171 return nil 172 } 173 174 type ReaderSizePair struct { 175 Reader io.ReadCloser 176 Size int 177 } 178 179 type ReaderCollection struct { 180 ReadersAndSizes []ReaderSizePair 181 currIdx int 182 currReaderRead int 183 } 184 185 func NewReaderCollection(readerSizePair ...ReaderSizePair) *ReaderCollection { 186 if len(readerSizePair) == 0 { 187 panic("no readers") 188 } 189 190 for _, rsp := range readerSizePair { 191 if rsp.Size <= 0 { 192 panic("invalid size") 193 } 194 195 if rsp.Reader == nil { 196 panic("invalid reader") 197 } 198 } 199 200 return &ReaderCollection{readerSizePair, 0, 0} 201 } 202 203 func (rc *ReaderCollection) Read(p []byte) (int, error) { 204 if rc.currIdx < len(rc.ReadersAndSizes) { 205 currReader := rc.ReadersAndSizes[rc.currIdx].Reader 206 currSize := rc.ReadersAndSizes[rc.currIdx].Size 207 remaining := currSize - rc.currReaderRead 208 209 n, err := currReader.Read(p) 210 211 if err != nil { 212 return 0, err 213 } 214 215 if n >= remaining { 216 n = remaining 217 rc.currIdx++ 218 rc.currReaderRead = 0 219 } else { 220 rc.currReaderRead += n 221 } 222 223 return n, err 224 } 225 226 return 0, io.EOF 227 } 228 229 func (rc *ReaderCollection) Close() error { 230 for _, rsp := range rc.ReadersAndSizes { 231 err := rsp.Reader.Close() 232 233 if err != nil { 234 return err 235 } 236 } 237 238 return nil 239 } 240 241 func TestReadWithMinThroughput(t *testing.T) { 242 t.Skip("Skipping test in all cases as it is inconsistent on Unix") 243 if osutil.IsWindows { 244 t.Skip("Skipping test as it is too inconsistent on Windows and will randomly pass or fail") 245 } 246 tests := []struct { 247 name string 248 numBytes int64 249 reader io.ReadCloser 250 mtcp MinThroughputCheckParams 251 expErr bool 252 expThroughErr bool 253 }{ 254 { 255 "10MB @ max(100MBps) > 50MBps", 256 10 * 1024 * 1024, 257 NewReaderCollection( 258 ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 10 * 1024 * 1024}, 259 ), 260 MinThroughputCheckParams{50 * 1024 * 1024, 5 * time.Millisecond, 10}, 261 false, 262 false, 263 }, 264 { 265 "5MB then error", 266 10 * 1024 * 1024, 267 NewReaderCollection( 268 ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 5 * 1024 * 1024}, 269 ReaderSizePair{ErroringReader{errors.New("test err")}, 100 * 1024}, 270 ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 5 * 1024 * 1024}, 271 ), 272 MinThroughputCheckParams{50 * 1024 * 1024, 5 * time.Millisecond, 10}, 273 true, 274 false, 275 }, 276 { 277 "5MB then slow < 50Mbps", 278 10 * 1024 * 1024, 279 NewReaderCollection( 280 ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 5 * 1024 * 1024}, 281 ReaderSizePair{NewFixedRateDataGenerator(49*1024, time.Millisecond), 5 * 1024 * 1024}, 282 ), 283 MinThroughputCheckParams{50 * 1024 * 1024, 5 * time.Millisecond, 10}, 284 false, 285 true, 286 }, 287 { 288 "5MB then stops", 289 10 * 1024 * 1024, 290 NewReaderCollection( 291 ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 5 * 1024 * 1024}, 292 ReaderSizePair{NewFixedRateDataGenerator(0, 100*time.Second), 5 * 1024 * 1024}, 293 ), 294 MinThroughputCheckParams{50 * 1024 * 1024, 5 * time.Millisecond, 10}, 295 false, 296 true, 297 }, 298 } 299 300 for _, test := range tests { 301 t.Run(test.name, func(t *testing.T) { 302 data, err := ReadWithMinThroughput(test.reader, test.numBytes, test.mtcp) 303 304 if test.expErr || test.expThroughErr { 305 if test.expThroughErr { 306 assert.Equal(t, err, ErrThroughput) 307 } else { 308 assert.Error(t, err) 309 assert.NotEqual(t, err, ErrThroughput) 310 } 311 } else { 312 assert.Equal(t, len(data), int(test.numBytes)) 313 assert.NoError(t, err) 314 } 315 }) 316 } 317 }