github.com/rclone/rclone@v1.66.1-0.20240517100346-7b89735ae726/fs/operations/reopen_test.go (about)

     1  package operations
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"testing"
     8  
     9  	"github.com/rclone/rclone/fs"
    10  	"github.com/rclone/rclone/fs/hash"
    11  	"github.com/rclone/rclone/fstest/mockobject"
    12  	"github.com/rclone/rclone/lib/pool"
    13  	"github.com/rclone/rclone/lib/readers"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  // check interfaces
    19  var (
    20  	_ io.ReadSeekCloser      = (*ReOpen)(nil)
    21  	_ pool.DelayAccountinger = (*ReOpen)(nil)
    22  )
    23  
    24  var errorTestError = errors.New("test error")
    25  
    26  // this is a wrapper for a mockobject with a custom Open function
    27  //
    28  // breaks indicate the number of bytes to read before returning an
    29  // error
    30  type reOpenTestObject struct {
    31  	fs.Object
    32  	t           *testing.T
    33  	wantStart   int64
    34  	breaks      []int64
    35  	unknownSize bool
    36  }
    37  
    38  // Open opens the file for read.  Call Close() on the returned io.ReadCloser
    39  //
    40  // This will break after reading the number of bytes in breaks
    41  func (o *reOpenTestObject) Open(ctx context.Context, options ...fs.OpenOption) (io.ReadCloser, error) {
    42  	// Lots of backends do this - make sure it works as it modifies options
    43  	fs.FixRangeOption(options, o.Size())
    44  	gotHash := false
    45  	gotRange := false
    46  	startPos := int64(0)
    47  	for _, option := range options {
    48  		switch x := option.(type) {
    49  		case *fs.HashesOption:
    50  			gotHash = true
    51  		case *fs.RangeOption:
    52  			gotRange = true
    53  			startPos = x.Start
    54  			if o.unknownSize {
    55  				assert.Equal(o.t, int64(-1), x.End)
    56  			}
    57  		case *fs.SeekOption:
    58  			startPos = x.Offset
    59  		}
    60  	}
    61  	assert.Equal(o.t, o.wantStart, startPos)
    62  	// Check if ranging, mustn't have hash if offset != 0
    63  	if gotHash && gotRange {
    64  		assert.Equal(o.t, int64(0), startPos)
    65  	}
    66  	rc, err := o.Object.Open(ctx, options...)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	if len(o.breaks) > 0 {
    71  		// Pop a breakpoint off
    72  		N := o.breaks[0]
    73  		o.breaks = o.breaks[1:]
    74  		o.wantStart += N
    75  		// If 0 then return an error immediately
    76  		if N == 0 {
    77  			return nil, errorTestError
    78  		}
    79  		// Read N bytes then an error
    80  		r := io.MultiReader(&io.LimitedReader{R: rc, N: N}, readers.ErrorReader{Err: errorTestError})
    81  		// Wrap with Close in a new readCloser
    82  		rc = readCloser{Reader: r, Closer: rc}
    83  	}
    84  	return rc, nil
    85  }
    86  
    87  func TestReOpen(t *testing.T) {
    88  	for _, testName := range []string{"Normal", "WithRangeOption", "WithSeekOption", "UnknownSize"} {
    89  		t.Run(testName, func(t *testing.T) {
    90  			// Contents for the mock object
    91  			var (
    92  				reOpenTestcontents = []byte("0123456789")
    93  				expectedRead       = reOpenTestcontents
    94  				rangeOption        *fs.RangeOption
    95  				seekOption         *fs.SeekOption
    96  				unknownSize        = false
    97  			)
    98  			switch testName {
    99  			case "Normal":
   100  			case "WithRangeOption":
   101  				rangeOption = &fs.RangeOption{Start: 1, End: 7} // range is inclusive
   102  				expectedRead = reOpenTestcontents[1:8]
   103  			case "WithSeekOption":
   104  				seekOption = &fs.SeekOption{Offset: 2}
   105  				expectedRead = reOpenTestcontents[2:]
   106  			case "UnknownSize":
   107  				rangeOption = &fs.RangeOption{Start: 1, End: -1}
   108  				expectedRead = reOpenTestcontents[1:]
   109  				unknownSize = true
   110  			default:
   111  				panic("bad test name")
   112  			}
   113  
   114  			// Start the test with the given breaks
   115  			testReOpen := func(breaks []int64, maxRetries int) (*ReOpen, *reOpenTestObject, error) {
   116  				srcOrig := mockobject.New("potato").WithContent(reOpenTestcontents, mockobject.SeekModeNone)
   117  				srcOrig.SetUnknownSize(unknownSize)
   118  				src := &reOpenTestObject{
   119  					Object:      srcOrig,
   120  					t:           t,
   121  					breaks:      breaks,
   122  					unknownSize: unknownSize,
   123  				}
   124  				opts := []fs.OpenOption{}
   125  				if rangeOption == nil && seekOption == nil {
   126  					opts = append(opts, &fs.HashesOption{Hashes: hash.NewHashSet(hash.MD5)})
   127  				}
   128  				if rangeOption != nil {
   129  					opts = append(opts, rangeOption)
   130  					src.wantStart = rangeOption.Start
   131  				}
   132  				if seekOption != nil {
   133  					opts = append(opts, seekOption)
   134  					src.wantStart = seekOption.Offset
   135  				}
   136  				rc, err := NewReOpen(context.Background(), src, maxRetries, opts...)
   137  				return rc, src, err
   138  			}
   139  
   140  			t.Run("Basics", func(t *testing.T) {
   141  				// open
   142  				h, _, err := testReOpen(nil, 10)
   143  				assert.NoError(t, err)
   144  
   145  				// Check contents read correctly
   146  				got, err := io.ReadAll(h)
   147  				assert.NoError(t, err)
   148  				assert.Equal(t, expectedRead, got)
   149  
   150  				// Check read after end
   151  				var buf = make([]byte, 1)
   152  				n, err := h.Read(buf)
   153  				assert.Equal(t, 0, n)
   154  				assert.Equal(t, io.EOF, err)
   155  
   156  				// Rewind the stream
   157  				_, err = h.Seek(0, io.SeekStart)
   158  				require.NoError(t, err)
   159  
   160  				// Check contents read correctly
   161  				got, err = io.ReadAll(h)
   162  				assert.NoError(t, err)
   163  				assert.Equal(t, expectedRead, got)
   164  
   165  				// Check close
   166  				assert.NoError(t, h.Close())
   167  
   168  				// Check double close
   169  				assert.Equal(t, errFileClosed, h.Close())
   170  
   171  				// Check read after close
   172  				n, err = h.Read(buf)
   173  				assert.Equal(t, 0, n)
   174  				assert.Equal(t, errFileClosed, err)
   175  			})
   176  
   177  			t.Run("ErrorAtStart", func(t *testing.T) {
   178  				// open with immediate breaking
   179  				h, _, err := testReOpen([]int64{0}, 10)
   180  				assert.Equal(t, errorTestError, err)
   181  				assert.Nil(t, h)
   182  			})
   183  
   184  			t.Run("WithErrors", func(t *testing.T) {
   185  				// open with a few break points but less than the max
   186  				h, _, err := testReOpen([]int64{2, 1, 3}, 10)
   187  				assert.NoError(t, err)
   188  
   189  				// check contents
   190  				got, err := io.ReadAll(h)
   191  				assert.NoError(t, err)
   192  				assert.Equal(t, expectedRead, got)
   193  
   194  				// check close
   195  				assert.NoError(t, h.Close())
   196  			})
   197  
   198  			t.Run("TooManyErrors", func(t *testing.T) {
   199  				// open with a few break points but >= the max
   200  				h, _, err := testReOpen([]int64{2, 1, 3}, 3)
   201  				assert.NoError(t, err)
   202  
   203  				// check contents
   204  				got, err := io.ReadAll(h)
   205  				assert.Equal(t, errorTestError, err)
   206  				assert.Equal(t, expectedRead[:6], got)
   207  
   208  				// check old error is returned
   209  				var buf = make([]byte, 1)
   210  				n, err := h.Read(buf)
   211  				assert.Equal(t, 0, n)
   212  				assert.Equal(t, errTooManyTries, err)
   213  
   214  				// Check close
   215  				assert.Equal(t, errFileClosed, h.Close())
   216  			})
   217  
   218  			t.Run("Seek", func(t *testing.T) {
   219  				// open
   220  				h, src, err := testReOpen([]int64{2, 1, 3}, 10)
   221  				assert.NoError(t, err)
   222  
   223  				// Seek to end
   224  				pos, err := h.Seek(int64(len(expectedRead)), io.SeekStart)
   225  				assert.NoError(t, err)
   226  				assert.Equal(t, int64(len(expectedRead)), pos)
   227  
   228  				// Seek to start
   229  				pos, err = h.Seek(0, io.SeekStart)
   230  				assert.NoError(t, err)
   231  				assert.Equal(t, int64(0), pos)
   232  
   233  				// Should not allow seek past end
   234  				pos, err = h.Seek(int64(len(expectedRead))+1, io.SeekCurrent)
   235  				if !unknownSize {
   236  					assert.Equal(t, errSeekPastEnd, err)
   237  					assert.Equal(t, len(expectedRead), int(pos))
   238  				} else {
   239  					assert.Equal(t, nil, err)
   240  					assert.Equal(t, len(expectedRead)+1, int(pos))
   241  
   242  					// Seek back to start to get tests in sync
   243  					pos, err = h.Seek(0, io.SeekStart)
   244  					assert.NoError(t, err)
   245  					assert.Equal(t, int64(0), pos)
   246  				}
   247  
   248  				// Should not allow seek to negative position start
   249  				pos, err = h.Seek(-1, io.SeekCurrent)
   250  				assert.Equal(t, errNegativeSeek, err)
   251  				assert.Equal(t, 0, int(pos))
   252  
   253  				// Should not allow seek with invalid whence
   254  				pos, err = h.Seek(0, 3)
   255  				assert.Equal(t, errInvalidWhence, err)
   256  				assert.Equal(t, 0, int(pos))
   257  
   258  				// check read
   259  				dst := make([]byte, 5)
   260  				n, err := h.Read(dst)
   261  				assert.Nil(t, err)
   262  				assert.Equal(t, 5, n)
   263  				assert.Equal(t, expectedRead[:5], dst)
   264  
   265  				// Test io.SeekCurrent
   266  				pos, err = h.Seek(-3, io.SeekCurrent)
   267  				assert.Nil(t, err)
   268  				assert.Equal(t, 2, int(pos))
   269  
   270  				// Reset the start after a seek, taking into account the offset
   271  				setWantStart := func(x int64) {
   272  					src.wantStart = x
   273  					if rangeOption != nil {
   274  						src.wantStart += rangeOption.Start
   275  					} else if seekOption != nil {
   276  						src.wantStart += seekOption.Offset
   277  					}
   278  				}
   279  
   280  				// check read
   281  				setWantStart(2)
   282  				n, err = h.Read(dst)
   283  				assert.Nil(t, err)
   284  				assert.Equal(t, 5, n)
   285  				assert.Equal(t, expectedRead[2:7], dst)
   286  
   287  				pos, err = h.Seek(-2, io.SeekCurrent)
   288  				assert.Nil(t, err)
   289  				assert.Equal(t, 5, int(pos))
   290  
   291  				// Test io.SeekEnd
   292  				pos, err = h.Seek(-3, io.SeekEnd)
   293  				if !unknownSize {
   294  					assert.Nil(t, err)
   295  					assert.Equal(t, len(expectedRead)-3, int(pos))
   296  				} else {
   297  					assert.Equal(t, errBadEndSeek, err)
   298  					assert.Equal(t, 0, int(pos))
   299  
   300  					// sync
   301  					pos, err = h.Seek(1, io.SeekCurrent)
   302  					assert.Nil(t, err)
   303  					assert.Equal(t, 6, int(pos))
   304  				}
   305  
   306  				// check read
   307  				dst = make([]byte, 3)
   308  				setWantStart(int64(len(expectedRead) - 3))
   309  				n, err = h.Read(dst)
   310  				assert.Nil(t, err)
   311  				assert.Equal(t, 3, n)
   312  				assert.Equal(t, expectedRead[len(expectedRead)-3:], dst)
   313  
   314  				// check close
   315  				assert.NoError(t, h.Close())
   316  				_, err = h.Seek(0, io.SeekCurrent)
   317  				assert.Equal(t, errFileClosed, err)
   318  			})
   319  
   320  			t.Run("AccountRead", func(t *testing.T) {
   321  				h, _, err := testReOpen(nil, 10)
   322  				assert.NoError(t, err)
   323  
   324  				var total int
   325  				h.SetAccounting(func(n int) error {
   326  					total += n
   327  					return nil
   328  				})
   329  
   330  				dst := make([]byte, 3)
   331  				n, err := h.Read(dst)
   332  				assert.Equal(t, 3, n)
   333  				assert.NoError(t, err)
   334  				assert.Equal(t, 3, total)
   335  			})
   336  
   337  			t.Run("AccountReadDelay", func(t *testing.T) {
   338  				h, _, err := testReOpen(nil, 10)
   339  				assert.NoError(t, err)
   340  
   341  				var total int
   342  				h.SetAccounting(func(n int) error {
   343  					total += n
   344  					return nil
   345  				})
   346  
   347  				rewind := func() {
   348  					_, err := h.Seek(0, io.SeekStart)
   349  					require.NoError(t, err)
   350  				}
   351  
   352  				h.DelayAccounting(3)
   353  
   354  				dst := make([]byte, 16)
   355  
   356  				n, err := h.Read(dst)
   357  				assert.Equal(t, len(expectedRead), n)
   358  				assert.Equal(t, io.EOF, err)
   359  				assert.Equal(t, 0, total)
   360  				rewind()
   361  
   362  				n, err = h.Read(dst)
   363  				assert.Equal(t, len(expectedRead), n)
   364  				assert.Equal(t, io.EOF, err)
   365  				assert.Equal(t, 0, total)
   366  				rewind()
   367  
   368  				n, err = h.Read(dst)
   369  				assert.Equal(t, len(expectedRead), n)
   370  				assert.Equal(t, io.EOF, err)
   371  				assert.Equal(t, len(expectedRead), total)
   372  				rewind()
   373  
   374  				n, err = h.Read(dst)
   375  				assert.Equal(t, len(expectedRead), n)
   376  				assert.Equal(t, io.EOF, err)
   377  				assert.Equal(t, 2*len(expectedRead), total)
   378  				rewind()
   379  			})
   380  
   381  			t.Run("AccountReadError", func(t *testing.T) {
   382  				// Test accounting errors
   383  				h, _, err := testReOpen(nil, 10)
   384  				assert.NoError(t, err)
   385  
   386  				h.SetAccounting(func(n int) error {
   387  					return errorTestError
   388  				})
   389  
   390  				dst := make([]byte, 3)
   391  				n, err := h.Read(dst)
   392  				assert.Equal(t, 3, n)
   393  				assert.Equal(t, errorTestError, err)
   394  			})
   395  		})
   396  	}
   397  }