github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/utils/iohelp/read_test.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package iohelp
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"errors"
    21  	"io"
    22  	"reflect"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  
    28  	"github.com/dolthub/dolt/go/libraries/utils/mathutil"
    29  	"github.com/dolthub/dolt/go/libraries/utils/osutil"
    30  	"github.com/dolthub/dolt/go/libraries/utils/test"
    31  )
    32  
    33  func TestErrPreservingReader(t *testing.T) {
    34  	tr := test.NewTestReader(32, 16)
    35  	epr := NewErrPreservingReader(tr)
    36  
    37  	read1, noErr1 := ReadNBytes(epr, 8)
    38  	read2, noErr2 := ReadNBytes(epr, 8)
    39  	read3, firstErr := ReadNBytes(epr, 8)
    40  	read4, secondErr := ReadNBytes(epr, 8)
    41  
    42  	for i := 0; i < 8; i++ {
    43  		if read1[i] != byte(i) || read2[i] != byte(i)+8 {
    44  			t.Error("Unexpected values read.")
    45  		}
    46  	}
    47  
    48  	if read3 != nil || read4 != nil {
    49  		t.Error("Unexpected read values should be nil.")
    50  	}
    51  
    52  	if noErr1 != nil || noErr2 != nil {
    53  		t.Error("Unexpected error.")
    54  	}
    55  
    56  	if firstErr == nil || secondErr == nil || epr.Err == nil {
    57  		t.Error("Expected error not received.")
    58  	} else {
    59  		first := firstErr.(*test.TestError).ErrId
    60  		second := secondErr.(*test.TestError).ErrId
    61  		preservedErrID := epr.Err.(*test.TestError).ErrId
    62  
    63  		if preservedErrID != first || preservedErrID != second {
    64  			t.Error("Error not preserved properly.")
    65  		}
    66  	}
    67  }
    68  
    69  var rlTests = []struct {
    70  	inputStr      string
    71  	expectedLines []string
    72  }{
    73  	{"line 1\nline 2\r\nline 3\n", []string{"line 1", "line 2", "line 3", ""}},
    74  	{"line 1\nline 2\r\nline 3", []string{"line 1", "line 2", "line 3"}},
    75  	{"\r\nline 1\nline 2\r\nline 3\r\r\r\n\n", []string{"", "line 1", "line 2", "line 3", "", ""}},
    76  }
    77  
    78  func TestReadReadLineFunctions(t *testing.T) {
    79  	for _, test := range rlTests {
    80  		bufferedTest := getTestReadLineClosure(test.inputStr)
    81  		unbufferedTest := getTestReadLineNoBufClosure(test.inputStr)
    82  
    83  		testReadLineFunctions(t, "buffered", test.expectedLines, bufferedTest)
    84  		testReadLineFunctions(t, "unbuffered", test.expectedLines, unbufferedTest)
    85  	}
    86  }
    87  
    88  func getTestReadLineClosure(inputStr string) func() (string, bool, error) {
    89  	r := bytes.NewReader([]byte(inputStr))
    90  	br := bufio.NewReader(r)
    91  
    92  	return func() (string, bool, error) {
    93  		return ReadLine(br)
    94  	}
    95  }
    96  
    97  func getTestReadLineNoBufClosure(inputStr string) func() (string, bool, error) {
    98  	r := bytes.NewReader([]byte(inputStr))
    99  
   100  	return func() (string, bool, error) {
   101  		return ReadLineNoBuf(r)
   102  	}
   103  }
   104  
   105  func testReadLineFunctions(t *testing.T, testType string, expected []string, rlFunc func() (string, bool, error)) {
   106  	var isDone bool
   107  	var line string
   108  	var err error
   109  
   110  	lines := make([]string, 0, len(expected))
   111  	for !isDone {
   112  		line, isDone, err = rlFunc()
   113  
   114  		if err == nil {
   115  			lines = append(lines, line)
   116  		}
   117  	}
   118  
   119  	if !reflect.DeepEqual(lines, expected) {
   120  		t.Error("Received unexpected results.")
   121  	}
   122  }
   123  
   124  var ErrClosed = errors.New("")
   125  
   126  type FixedRateDataGenerator struct {
   127  	BytesPerInterval int
   128  	Interval         time.Duration
   129  	lastRead         time.Time
   130  	closeChan        chan struct{}
   131  	dataGenerated    uint64
   132  }
   133  
   134  func NewFixedRateDataGenerator(bytesPerInterval int, interval time.Duration) *FixedRateDataGenerator {
   135  	return &FixedRateDataGenerator{
   136  		bytesPerInterval,
   137  		interval,
   138  		time.Now(),
   139  		make(chan struct{}),
   140  		0,
   141  	}
   142  }
   143  
   144  func (gen *FixedRateDataGenerator) Read(p []byte) (int, error) {
   145  	nextRead := gen.Interval - (time.Now().Sub(gen.lastRead))
   146  
   147  	select {
   148  	case <-gen.closeChan:
   149  		return 0, ErrClosed
   150  	case <-time.After(nextRead):
   151  		gen.dataGenerated += uint64(gen.BytesPerInterval)
   152  		gen.lastRead = time.Now()
   153  		return mathutil.Min(gen.BytesPerInterval, len(p)), nil
   154  	}
   155  }
   156  
   157  func (gen *FixedRateDataGenerator) Close() error {
   158  	close(gen.closeChan)
   159  	return nil
   160  }
   161  
   162  type ErroringReader struct {
   163  	Err error
   164  }
   165  
   166  func (er ErroringReader) Read(p []byte) (int, error) {
   167  	return 0, er.Err
   168  }
   169  
   170  func (er ErroringReader) Close() error {
   171  	return nil
   172  }
   173  
   174  type ReaderSizePair struct {
   175  	Reader io.ReadCloser
   176  	Size   int
   177  }
   178  
   179  type ReaderCollection struct {
   180  	ReadersAndSizes []ReaderSizePair
   181  	currIdx         int
   182  	currReaderRead  int
   183  }
   184  
   185  func NewReaderCollection(readerSizePair ...ReaderSizePair) *ReaderCollection {
   186  	if len(readerSizePair) == 0 {
   187  		panic("no readers")
   188  	}
   189  
   190  	for _, rsp := range readerSizePair {
   191  		if rsp.Size <= 0 {
   192  			panic("invalid size")
   193  		}
   194  
   195  		if rsp.Reader == nil {
   196  			panic("invalid reader")
   197  		}
   198  	}
   199  
   200  	return &ReaderCollection{readerSizePair, 0, 0}
   201  }
   202  
   203  func (rc *ReaderCollection) Read(p []byte) (int, error) {
   204  	if rc.currIdx < len(rc.ReadersAndSizes) {
   205  		currReader := rc.ReadersAndSizes[rc.currIdx].Reader
   206  		currSize := rc.ReadersAndSizes[rc.currIdx].Size
   207  		remaining := currSize - rc.currReaderRead
   208  
   209  		n, err := currReader.Read(p)
   210  
   211  		if err != nil {
   212  			return 0, err
   213  		}
   214  
   215  		if n >= remaining {
   216  			n = remaining
   217  			rc.currIdx++
   218  			rc.currReaderRead = 0
   219  		} else {
   220  			rc.currReaderRead += n
   221  		}
   222  
   223  		return n, err
   224  	}
   225  
   226  	return 0, io.EOF
   227  }
   228  
   229  func (rc *ReaderCollection) Close() error {
   230  	for _, rsp := range rc.ReadersAndSizes {
   231  		err := rsp.Reader.Close()
   232  
   233  		if err != nil {
   234  			return err
   235  		}
   236  	}
   237  
   238  	return nil
   239  }
   240  
   241  func TestReadWithMinThroughput(t *testing.T) {
   242  	t.Skip("Skipping test in all cases as it is inconsistent on Unix")
   243  	if osutil.IsWindows {
   244  		t.Skip("Skipping test as it is too inconsistent on Windows and will randomly pass or fail")
   245  	}
   246  	tests := []struct {
   247  		name          string
   248  		numBytes      int64
   249  		reader        io.ReadCloser
   250  		mtcp          MinThroughputCheckParams
   251  		expErr        bool
   252  		expThroughErr bool
   253  	}{
   254  		{
   255  			"10MB @ max(100MBps) > 50MBps",
   256  			10 * 1024 * 1024,
   257  			NewReaderCollection(
   258  				ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 10 * 1024 * 1024},
   259  			),
   260  			MinThroughputCheckParams{50 * 1024 * 1024, 5 * time.Millisecond, 10},
   261  			false,
   262  			false,
   263  		},
   264  		{
   265  			"5MB then error",
   266  			10 * 1024 * 1024,
   267  			NewReaderCollection(
   268  				ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 5 * 1024 * 1024},
   269  				ReaderSizePair{ErroringReader{errors.New("test err")}, 100 * 1024},
   270  				ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 5 * 1024 * 1024},
   271  			),
   272  			MinThroughputCheckParams{50 * 1024 * 1024, 5 * time.Millisecond, 10},
   273  			true,
   274  			false,
   275  		},
   276  		{
   277  			"5MB then slow < 50Mbps",
   278  			10 * 1024 * 1024,
   279  			NewReaderCollection(
   280  				ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 5 * 1024 * 1024},
   281  				ReaderSizePair{NewFixedRateDataGenerator(49*1024, time.Millisecond), 5 * 1024 * 1024},
   282  			),
   283  			MinThroughputCheckParams{50 * 1024 * 1024, 5 * time.Millisecond, 10},
   284  			false,
   285  			true,
   286  		},
   287  		{
   288  			"5MB then stops",
   289  			10 * 1024 * 1024,
   290  			NewReaderCollection(
   291  				ReaderSizePair{NewFixedRateDataGenerator(100*1024, time.Millisecond), 5 * 1024 * 1024},
   292  				ReaderSizePair{NewFixedRateDataGenerator(0, 100*time.Second), 5 * 1024 * 1024},
   293  			),
   294  			MinThroughputCheckParams{50 * 1024 * 1024, 5 * time.Millisecond, 10},
   295  			false,
   296  			true,
   297  		},
   298  	}
   299  
   300  	for _, test := range tests {
   301  		t.Run(test.name, func(t *testing.T) {
   302  			data, err := ReadWithMinThroughput(test.reader, test.numBytes, test.mtcp)
   303  
   304  			if test.expErr || test.expThroughErr {
   305  				if test.expThroughErr {
   306  					assert.Equal(t, err, ErrThroughput)
   307  				} else {
   308  					assert.Error(t, err)
   309  					assert.NotEqual(t, err, ErrThroughput)
   310  				}
   311  			} else {
   312  				assert.Equal(t, len(data), int(test.numBytes))
   313  				assert.NoError(t, err)
   314  			}
   315  		})
   316  	}
   317  }