github.com/rclone/rclone@v1.66.1-0.20240517100346-7b89735ae726/lib/pool/reader_writer_test.go (about)

     1  package pool
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/rclone/rclone/lib/random"
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  const blockSize = 4096
    16  
    17  var rwPool = New(60*time.Second, blockSize, 2, false)
    18  
    19  // A writer that always returns an error
    20  type testWriterError struct{}
    21  
    22  var errWriteError = errors.New("write error")
    23  
    24  func (testWriterError) Write(p []byte) (n int, err error) {
    25  	return 0, errWriteError
    26  }
    27  
    28  func TestRW(t *testing.T) {
    29  	var dst []byte
    30  	var pos int64
    31  	var err error
    32  	var n int
    33  
    34  	testData := []byte("Goodness!!") // 10 bytes long
    35  
    36  	newRW := func() *RW {
    37  		rw := NewRW(rwPool)
    38  		buf := bytes.NewBuffer(testData)
    39  		nn, err := rw.ReadFrom(buf) // fill up with goodness
    40  		assert.NoError(t, err)
    41  		assert.Equal(t, int64(10), nn)
    42  		assert.Equal(t, int64(10), rw.Size())
    43  		return rw
    44  	}
    45  
    46  	close := func(rw *RW) {
    47  		assert.NoError(t, rw.Close())
    48  	}
    49  
    50  	t.Run("Empty", func(t *testing.T) {
    51  		// Test empty read
    52  		rw := NewRW(rwPool)
    53  		defer close(rw)
    54  		assert.Equal(t, int64(0), rw.Size())
    55  
    56  		dst = make([]byte, 10)
    57  		n, err = rw.Read(dst)
    58  		assert.Equal(t, io.EOF, err)
    59  		assert.Equal(t, 0, n)
    60  		assert.Equal(t, int64(0), rw.Size())
    61  	})
    62  
    63  	t.Run("Full", func(t *testing.T) {
    64  		rw := newRW()
    65  		defer close(rw)
    66  
    67  		// Test full read
    68  		dst = make([]byte, 100)
    69  		n, err = rw.Read(dst)
    70  		assert.Equal(t, io.EOF, err)
    71  		assert.Equal(t, 10, n)
    72  		assert.Equal(t, testData, dst[0:10])
    73  
    74  		// Test read EOF
    75  		n, err = rw.Read(dst)
    76  		assert.Equal(t, io.EOF, err)
    77  		assert.Equal(t, 0, n)
    78  
    79  		// Test Seek Back to start
    80  		dst = make([]byte, 10)
    81  		pos, err = rw.Seek(0, io.SeekStart)
    82  		assert.Nil(t, err)
    83  		assert.Equal(t, 0, int(pos))
    84  
    85  		// Now full read
    86  		n, err = rw.Read(dst)
    87  		assert.Nil(t, err)
    88  		assert.Equal(t, 10, n)
    89  		assert.Equal(t, testData, dst)
    90  	})
    91  
    92  	t.Run("WriteTo", func(t *testing.T) {
    93  		rw := newRW()
    94  		defer close(rw)
    95  		var b bytes.Buffer
    96  
    97  		n, err := rw.WriteTo(&b)
    98  		assert.NoError(t, err)
    99  		assert.Equal(t, int64(10), n)
   100  		assert.Equal(t, testData, b.Bytes())
   101  	})
   102  
   103  	t.Run("WriteToError", func(t *testing.T) {
   104  		rw := newRW()
   105  		defer close(rw)
   106  		w := testWriterError{}
   107  
   108  		n, err := rw.WriteTo(w)
   109  		assert.Equal(t, errWriteError, err)
   110  		assert.Equal(t, int64(0), n)
   111  	})
   112  
   113  	t.Run("Partial", func(t *testing.T) {
   114  		// Test partial read
   115  		rw := newRW()
   116  		defer close(rw)
   117  
   118  		dst = make([]byte, 5)
   119  		n, err = rw.Read(dst)
   120  		assert.Nil(t, err)
   121  		assert.Equal(t, 5, n)
   122  		assert.Equal(t, testData[0:5], dst)
   123  		n, err = rw.Read(dst)
   124  		assert.Nil(t, err)
   125  		assert.Equal(t, 5, n)
   126  		assert.Equal(t, testData[5:], dst)
   127  	})
   128  
   129  	t.Run("Seek", func(t *testing.T) {
   130  		// Test Seek
   131  		rw := newRW()
   132  		defer close(rw)
   133  
   134  		// Seek to end
   135  		pos, err = rw.Seek(10, io.SeekStart)
   136  		assert.NoError(t, err)
   137  		assert.Equal(t, int64(10), pos)
   138  
   139  		// Seek to start
   140  		pos, err = rw.Seek(0, io.SeekStart)
   141  		assert.NoError(t, err)
   142  		assert.Equal(t, int64(0), pos)
   143  
   144  		// Should not allow seek past cache index
   145  		pos, err = rw.Seek(11, io.SeekCurrent)
   146  		assert.Equal(t, errSeekPastEnd, err)
   147  		assert.Equal(t, 10, int(pos))
   148  
   149  		// Should not allow seek to negative position start
   150  		pos, err = rw.Seek(-1, io.SeekCurrent)
   151  		assert.Equal(t, errNegativeSeek, err)
   152  		assert.Equal(t, 0, int(pos))
   153  
   154  		// Should not allow seek with invalid whence
   155  		pos, err = rw.Seek(0, 3)
   156  		assert.Equal(t, errInvalidWhence, err)
   157  		assert.Equal(t, 0, int(pos))
   158  
   159  		// Should seek from index with io.SeekCurrent(1) whence
   160  		dst = make([]byte, 5)
   161  		_, _ = rw.Read(dst)
   162  		pos, err = rw.Seek(-3, io.SeekCurrent)
   163  		assert.Nil(t, err)
   164  		assert.Equal(t, 2, int(pos))
   165  		pos, err = rw.Seek(1, io.SeekCurrent)
   166  		assert.Nil(t, err)
   167  		assert.Equal(t, 3, int(pos))
   168  
   169  		// Should seek from cache end with io.SeekEnd(2) whence
   170  		pos, err = rw.Seek(-3, io.SeekEnd)
   171  		assert.Nil(t, err)
   172  		assert.Equal(t, 7, int(pos))
   173  
   174  		// Should read from seek position and past it
   175  		dst = make([]byte, 3)
   176  		n, err = io.ReadFull(rw, dst)
   177  		assert.Nil(t, err)
   178  		assert.Equal(t, 3, n)
   179  		assert.Equal(t, testData[7:10], dst)
   180  	})
   181  
   182  	t.Run("Account", func(t *testing.T) {
   183  		errBoom := errors.New("accounting error")
   184  
   185  		t.Run("Read", func(t *testing.T) {
   186  			rw := newRW()
   187  			defer close(rw)
   188  
   189  			var total int
   190  			rw.SetAccounting(func(n int) error {
   191  				total += n
   192  				return nil
   193  			})
   194  
   195  			dst = make([]byte, 3)
   196  			n, err = rw.Read(dst)
   197  			assert.Equal(t, 3, n)
   198  			assert.NoError(t, err)
   199  			assert.Equal(t, 3, total)
   200  		})
   201  
   202  		t.Run("WriteTo", func(t *testing.T) {
   203  			rw := newRW()
   204  			defer close(rw)
   205  			var b bytes.Buffer
   206  
   207  			var total int
   208  			rw.SetAccounting(func(n int) error {
   209  				total += n
   210  				return nil
   211  			})
   212  
   213  			n, err := rw.WriteTo(&b)
   214  			assert.NoError(t, err)
   215  			assert.Equal(t, 10, total)
   216  			assert.Equal(t, int64(10), n)
   217  			assert.Equal(t, testData, b.Bytes())
   218  		})
   219  
   220  		t.Run("ReadDelay", func(t *testing.T) {
   221  			rw := newRW()
   222  			defer close(rw)
   223  
   224  			var total int
   225  			rw.SetAccounting(func(n int) error {
   226  				total += n
   227  				return nil
   228  			})
   229  
   230  			rewind := func() {
   231  				_, err := rw.Seek(0, io.SeekStart)
   232  				require.NoError(t, err)
   233  			}
   234  
   235  			rw.DelayAccounting(3)
   236  
   237  			dst = make([]byte, 16)
   238  
   239  			n, err = rw.Read(dst)
   240  			assert.Equal(t, 10, n)
   241  			assert.Equal(t, io.EOF, err)
   242  			assert.Equal(t, 0, total)
   243  			rewind()
   244  
   245  			n, err = rw.Read(dst)
   246  			assert.Equal(t, 10, n)
   247  			assert.Equal(t, io.EOF, err)
   248  			assert.Equal(t, 0, total)
   249  			rewind()
   250  
   251  			n, err = rw.Read(dst)
   252  			assert.Equal(t, 10, n)
   253  			assert.Equal(t, io.EOF, err)
   254  			assert.Equal(t, 10, total)
   255  			rewind()
   256  
   257  			n, err = rw.Read(dst)
   258  			assert.Equal(t, 10, n)
   259  			assert.Equal(t, io.EOF, err)
   260  			assert.Equal(t, 20, total)
   261  			rewind()
   262  		})
   263  
   264  		t.Run("WriteToDelay", func(t *testing.T) {
   265  			rw := newRW()
   266  			defer close(rw)
   267  			var b bytes.Buffer
   268  
   269  			var total int
   270  			rw.SetAccounting(func(n int) error {
   271  				total += n
   272  				return nil
   273  			})
   274  
   275  			rw.DelayAccounting(3)
   276  
   277  			rewind := func() {
   278  				_, err := rw.Seek(0, io.SeekStart)
   279  				require.NoError(t, err)
   280  				b.Reset()
   281  			}
   282  
   283  			n, err := rw.WriteTo(&b)
   284  			assert.NoError(t, err)
   285  			assert.Equal(t, 0, total)
   286  			assert.Equal(t, int64(10), n)
   287  			assert.Equal(t, testData, b.Bytes())
   288  			rewind()
   289  
   290  			n, err = rw.WriteTo(&b)
   291  			assert.NoError(t, err)
   292  			assert.Equal(t, 0, total)
   293  			assert.Equal(t, int64(10), n)
   294  			assert.Equal(t, testData, b.Bytes())
   295  			rewind()
   296  
   297  			n, err = rw.WriteTo(&b)
   298  			assert.NoError(t, err)
   299  			assert.Equal(t, 10, total)
   300  			assert.Equal(t, int64(10), n)
   301  			assert.Equal(t, testData, b.Bytes())
   302  			rewind()
   303  
   304  			n, err = rw.WriteTo(&b)
   305  			assert.NoError(t, err)
   306  			assert.Equal(t, 20, total)
   307  			assert.Equal(t, int64(10), n)
   308  			assert.Equal(t, testData, b.Bytes())
   309  			rewind()
   310  		})
   311  
   312  		t.Run("ReadError", func(t *testing.T) {
   313  			// Test accounting errors
   314  			rw := newRW()
   315  			defer close(rw)
   316  
   317  			rw.SetAccounting(func(n int) error {
   318  				return errBoom
   319  			})
   320  
   321  			dst = make([]byte, 3)
   322  			n, err = rw.Read(dst)
   323  			assert.Equal(t, 3, n)
   324  			assert.Equal(t, errBoom, err)
   325  		})
   326  
   327  		t.Run("WriteToError", func(t *testing.T) {
   328  			rw := newRW()
   329  			defer close(rw)
   330  			rw.SetAccounting(func(n int) error {
   331  				return errBoom
   332  			})
   333  			var b bytes.Buffer
   334  
   335  			n, err := rw.WriteTo(&b)
   336  			assert.Equal(t, errBoom, err)
   337  			assert.Equal(t, int64(10), n)
   338  			assert.Equal(t, testData, b.Bytes())
   339  		})
   340  	})
   341  
   342  }
   343  
   344  // A reader to read in chunkSize chunks
   345  type testReader struct {
   346  	data      []byte
   347  	chunkSize int
   348  }
   349  
   350  // Read in chunkSize chunks
   351  func (r *testReader) Read(p []byte) (n int, err error) {
   352  	if len(r.data) == 0 {
   353  		return 0, io.EOF
   354  	}
   355  	chunkSize := r.chunkSize
   356  	if chunkSize > len(r.data) {
   357  		chunkSize = len(r.data)
   358  	}
   359  	n = copy(p, r.data[:chunkSize])
   360  	r.data = r.data[n:]
   361  	return n, nil
   362  }
   363  
   364  // A writer to write in chunkSize chunks
   365  type testWriter struct {
   366  	t         *testing.T
   367  	data      []byte
   368  	chunkSize int
   369  	buf       []byte
   370  	offset    int
   371  }
   372  
   373  // Write in chunkSize chunks
   374  func (w *testWriter) Write(p []byte) (n int, err error) {
   375  	if w.buf == nil {
   376  		w.buf = make([]byte, w.chunkSize)
   377  	}
   378  	n = copy(w.buf, p)
   379  	assert.Equal(w.t, w.data[w.offset:w.offset+n], w.buf[:n])
   380  	w.offset += n
   381  	return n, nil
   382  }
   383  
   384  func TestRWBoundaryConditions(t *testing.T) {
   385  	var accounted int
   386  	account := func(n int) error {
   387  		accounted += n
   388  		return nil
   389  	}
   390  
   391  	maxSize := 3 * blockSize
   392  	buf := []byte(random.String(maxSize))
   393  
   394  	sizes := []int{
   395  		1, 2, 3,
   396  		blockSize - 2, blockSize - 1, blockSize, blockSize + 1, blockSize + 2,
   397  		2*blockSize - 2, 2*blockSize - 1, 2 * blockSize, 2*blockSize + 1, 2*blockSize + 2,
   398  		3*blockSize - 2, 3*blockSize - 1, 3 * blockSize,
   399  	}
   400  
   401  	// Write the data in chunkSize chunks
   402  	write := func(rw *RW, data []byte, chunkSize int) {
   403  		writeData := data
   404  		for len(writeData) > 0 {
   405  			i := chunkSize
   406  			if i > len(writeData) {
   407  				i = len(writeData)
   408  			}
   409  			nn, err := rw.Write(writeData[:i])
   410  			assert.NoError(t, err)
   411  			assert.Equal(t, len(writeData[:i]), nn)
   412  			writeData = writeData[nn:]
   413  		}
   414  	}
   415  
   416  	// Write the data in chunkSize chunks using ReadFrom
   417  	readFrom := func(rw *RW, data []byte, chunkSize int) {
   418  		nn, err := rw.ReadFrom(&testReader{
   419  			data:      data,
   420  			chunkSize: chunkSize,
   421  		})
   422  		assert.NoError(t, err)
   423  		assert.Equal(t, int64(len(data)), nn)
   424  	}
   425  
   426  	// Read the data back and check it is OK in chunkSize chunks
   427  	read := func(rw *RW, data []byte, chunkSize int) {
   428  		size := len(data)
   429  		buf := make([]byte, chunkSize)
   430  		offset := 0
   431  		for {
   432  			nn, err := rw.Read(buf)
   433  			expectedRead := len(buf)
   434  			if offset+chunkSize > size {
   435  				expectedRead = size - offset
   436  				assert.Equal(t, err, io.EOF)
   437  			} else {
   438  				assert.NoError(t, err)
   439  			}
   440  			assert.Equal(t, expectedRead, nn)
   441  			assert.Equal(t, data[offset:offset+nn], buf[:nn])
   442  			offset += nn
   443  			if err == io.EOF {
   444  				break
   445  			}
   446  		}
   447  	}
   448  
   449  	// Read the data back and check it is OK in chunkSize chunks using WriteTo
   450  	writeTo := func(rw *RW, data []byte, chunkSize int) {
   451  		nn, err := rw.WriteTo(&testWriter{
   452  			t:         t,
   453  			data:      data,
   454  			chunkSize: chunkSize,
   455  		})
   456  		assert.NoError(t, err)
   457  		assert.Equal(t, int64(len(data)), nn)
   458  	}
   459  
   460  	type test struct {
   461  		name string
   462  		fn   func(*RW, []byte, int)
   463  	}
   464  
   465  	// Read and Write the data with a range of block sizes and functions
   466  	for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} {
   467  		t.Run(write.name, func(t *testing.T) {
   468  			for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} {
   469  				t.Run(read.name, func(t *testing.T) {
   470  					for _, size := range sizes {
   471  						data := buf[:size]
   472  						for _, chunkSize := range sizes {
   473  							//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
   474  							rw := NewRW(rwPool)
   475  							assert.Equal(t, int64(0), rw.Size())
   476  							accounted = 0
   477  							rw.SetAccounting(account)
   478  							assert.Equal(t, 0, accounted)
   479  							write.fn(rw, data, chunkSize)
   480  							assert.Equal(t, int64(size), rw.Size())
   481  							assert.Equal(t, 0, accounted)
   482  							read.fn(rw, data, chunkSize)
   483  							assert.NoError(t, rw.Close())
   484  							assert.Equal(t, size, accounted)
   485  						}
   486  					}
   487  				})
   488  			}
   489  		})
   490  	}
   491  }