aboutsummaryrefslogtreecommitdiff
path: root/fs/operations/reopen_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'fs/operations/reopen_test.go')
-rw-r--r--fs/operations/reopen_test.go429
1 files changed, 429 insertions, 0 deletions
diff --git a/fs/operations/reopen_test.go b/fs/operations/reopen_test.go
new file mode 100644
index 0000000..19f6ca3
--- /dev/null
+++ b/fs/operations/reopen_test.go
@@ -0,0 +1,429 @@
+package operations
+
+import (
+ "context"
+ "errors"
+ "io"
+ "testing"
+
+ "github.com/rclone/rclone/fs"
+ "github.com/rclone/rclone/fs/hash"
+ "github.com/rclone/rclone/fstest/mockobject"
+ "github.com/rclone/rclone/lib/pool"
+ "github.com/rclone/rclone/lib/readers"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// check interfaces
+var (
+ _ io.ReadSeekCloser = (*ReOpen)(nil)
+ _ io.ReaderAt = (*ReOpen)(nil)
+ _ pool.DelayAccountinger = (*ReOpen)(nil)
+)
+
+var errorTestError = errors.New("test error")
+
+// this is a wrapper for a mockobject with a custom Open function
+//
+// breaks indicate the number of bytes to read before returning an
+// error
+type reOpenTestObject struct {
+ fs.Object
+ t *testing.T
+ wantStart int64
+ breaks []int64
+ unknownSize bool
+}
+
+// Open opens the file for read. Call Close() on the returned io.ReadCloser
+//
+// This will break after reading the number of bytes in breaks
+func (o *reOpenTestObject) Open(ctx context.Context, options ...fs.OpenOption) (io.ReadCloser, error) {
+ // Lots of backends do this - make sure it works as it modifies options
+ fs.FixRangeOption(options, o.Size())
+ gotHash := false
+ gotRange := false
+ startPos := int64(0)
+ for _, option := range options {
+ switch x := option.(type) {
+ case *fs.HashesOption:
+ gotHash = true
+ case *fs.RangeOption:
+ gotRange = true
+ startPos = x.Start
+ if o.unknownSize {
+ assert.Equal(o.t, int64(-1), x.End)
+ }
+ case *fs.SeekOption:
+ startPos = x.Offset
+ }
+ }
+ assert.Equal(o.t, o.wantStart, startPos)
+ // Check if ranging, mustn't have hash if offset != 0
+ if gotHash && gotRange {
+ assert.Equal(o.t, int64(0), startPos)
+ }
+ rc, err := o.Object.Open(ctx, options...)
+ if err != nil {
+ return nil, err
+ }
+ if len(o.breaks) > 0 {
+ // Pop a breakpoint off
+ N := o.breaks[0]
+ o.breaks = o.breaks[1:]
+ o.wantStart += N
+ // If 0 then return an error immediately
+ if N == 0 {
+ return nil, errorTestError
+ }
+ // Read N bytes then an error
+ r := io.MultiReader(&io.LimitedReader{R: rc, N: N}, readers.ErrorReader{Err: errorTestError})
+ // Wrap with Close in a new readCloser
+ rc = readCloser{Reader: r, Closer: rc}
+ }
+ return rc, nil
+}
+
+func TestReOpen(t *testing.T) {
+ for _, testName := range []string{"Normal", "WithRangeOption", "WithSeekOption", "UnknownSize"} {
+ t.Run(testName, func(t *testing.T) {
+ // Contents for the mock object
+ var (
+ reOpenTestcontents = []byte("0123456789")
+ expectedRead = reOpenTestcontents
+ rangeOption *fs.RangeOption
+ seekOption *fs.SeekOption
+ unknownSize = false
+ )
+ switch testName {
+ case "Normal":
+ case "WithRangeOption":
+ rangeOption = &fs.RangeOption{Start: 1, End: 7} // range is inclusive
+ expectedRead = reOpenTestcontents[1:8]
+ case "WithSeekOption":
+ seekOption = &fs.SeekOption{Offset: 2}
+ expectedRead = reOpenTestcontents[2:]
+ case "UnknownSize":
+ rangeOption = &fs.RangeOption{Start: 1, End: -1}
+ expectedRead = reOpenTestcontents[1:]
+ unknownSize = true
+ default:
+ panic("bad test name")
+ }
+
+ // Start the test with the given breaks
+ testReOpen := func(breaks []int64, maxRetries int) (*ReOpen, *reOpenTestObject, error) {
+ srcOrig := mockobject.New("potato").WithContent(reOpenTestcontents, mockobject.SeekModeNone)
+ srcOrig.SetUnknownSize(unknownSize)
+ src := &reOpenTestObject{
+ Object: srcOrig,
+ t: t,
+ breaks: breaks,
+ unknownSize: unknownSize,
+ }
+ opts := []fs.OpenOption{}
+ if rangeOption == nil && seekOption == nil {
+ opts = append(opts, &fs.HashesOption{Hashes: hash.NewHashSet(hash.MD5)})
+ }
+ if rangeOption != nil {
+ opts = append(opts, rangeOption)
+ src.wantStart = rangeOption.Start
+ }
+ if seekOption != nil {
+ opts = append(opts, seekOption)
+ src.wantStart = seekOption.Offset
+ }
+ rc, err := NewReOpen(context.Background(), src, maxRetries, opts...)
+ return rc, src, err
+ }
+
+ // Reset the start after a seek, taking into account the offset
+ setWantStart := func(src *reOpenTestObject, x int64) {
+ src.wantStart = x
+ if rangeOption != nil {
+ src.wantStart += rangeOption.Start
+ } else if seekOption != nil {
+ src.wantStart += seekOption.Offset
+ }
+ }
+
+ t.Run("Basics", func(t *testing.T) {
+ // open
+ h, _, err := testReOpen(nil, 10)
+ assert.NoError(t, err)
+
+ // Check contents read correctly
+ got, err := io.ReadAll(h)
+ assert.NoError(t, err)
+ assert.Equal(t, expectedRead, got)
+
+ // Check read after end
+ var buf = make([]byte, 1)
+ n, err := h.Read(buf)
+ assert.Equal(t, 0, n)
+ assert.Equal(t, io.EOF, err)
+
+ // Rewind the stream
+ _, err = h.Seek(0, io.SeekStart)
+ require.NoError(t, err)
+
+ // Check contents read correctly
+ got, err = io.ReadAll(h)
+ assert.NoError(t, err)
+ assert.Equal(t, expectedRead, got)
+
+ // Check close
+ assert.NoError(t, h.Close())
+
+ // Check double close
+ assert.Equal(t, errFileClosed, h.Close())
+
+ // Check read after close
+ n, err = h.Read(buf)
+ assert.Equal(t, 0, n)
+ assert.Equal(t, errFileClosed, err)
+ })
+
+ t.Run("ErrorAtStart", func(t *testing.T) {
+ // open with immediate breaking
+ h, _, err := testReOpen([]int64{0}, 10)
+ assert.Equal(t, errorTestError, err)
+ assert.Nil(t, h)
+ })
+
+ t.Run("WithErrors", func(t *testing.T) {
+ // open with a few break points but less than the max
+ h, _, err := testReOpen([]int64{2, 1, 3}, 10)
+ assert.NoError(t, err)
+
+ // check contents
+ got, err := io.ReadAll(h)
+ assert.NoError(t, err)
+ assert.Equal(t, expectedRead, got)
+
+ // check close
+ assert.NoError(t, h.Close())
+ })
+
+ t.Run("TooManyErrors", func(t *testing.T) {
+ // open with a few break points but >= the max
+ h, _, err := testReOpen([]int64{2, 1, 3}, 3)
+ assert.NoError(t, err)
+
+ // check contents
+ got, err := io.ReadAll(h)
+ assert.Equal(t, errorTestError, err)
+ assert.Equal(t, expectedRead[:6], got)
+
+ // check old error is returned
+ var buf = make([]byte, 1)
+ n, err := h.Read(buf)
+ assert.Equal(t, 0, n)
+ assert.Equal(t, errTooManyTries, err)
+
+ // Check close
+ assert.Equal(t, errFileClosed, h.Close())
+ })
+
+ t.Run("ReadAt", func(t *testing.T) {
+ // open
+ h, src, err := testReOpen([]int64{2, 1, 3}, 10)
+ assert.NoError(t, err)
+
+ buf := make([]byte, 5)
+
+ // Read at 0
+ n, err := h.ReadAt(buf, 0)
+ require.NoError(t, err)
+ assert.Equal(t, 5, n)
+ assert.Equal(t, expectedRead[:n], buf[:n])
+
+ // Read at 1
+ setWantStart(src, 1)
+ n, err = h.ReadAt(buf[:3], 1)
+ require.NoError(t, err)
+ assert.Equal(t, 3, n)
+ assert.Equal(t, expectedRead[1:n+1], buf[:n])
+
+ // check position unchanged
+ pos, err := h.Seek(0, io.SeekCurrent)
+ require.NoError(t, err)
+ assert.Equal(t, int64(0), pos)
+
+ // check close
+ assert.NoError(t, h.Close())
+ _, err = h.Seek(0, io.SeekCurrent)
+ assert.Equal(t, errFileClosed, err)
+ })
+
+ t.Run("Seek", func(t *testing.T) {
+ // open
+ h, src, err := testReOpen([]int64{2, 1, 3}, 10)
+ assert.NoError(t, err)
+
+ // Seek to end
+ pos, err := h.Seek(int64(len(expectedRead)), io.SeekStart)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(len(expectedRead)), pos)
+
+ // Seek to start
+ pos, err = h.Seek(0, io.SeekStart)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(0), pos)
+
+ // Should not allow seek past end
+ pos, err = h.Seek(int64(len(expectedRead))+1, io.SeekCurrent)
+ if !unknownSize {
+ assert.Equal(t, errSeekPastEnd, err)
+ assert.Equal(t, len(expectedRead), int(pos))
+ } else {
+ assert.Equal(t, nil, err)
+ assert.Equal(t, len(expectedRead)+1, int(pos))
+
+ // Seek back to start to get tests in sync
+ pos, err = h.Seek(0, io.SeekStart)
+ assert.NoError(t, err)
+ assert.Equal(t, int64(0), pos)
+ }
+
+ // Should not allow seek to negative position start
+ pos, err = h.Seek(-1, io.SeekCurrent)
+ assert.Equal(t, errNegativeSeek, err)
+ assert.Equal(t, 0, int(pos))
+
+ // Should not allow seek with invalid whence
+ pos, err = h.Seek(0, 3)
+ assert.Equal(t, errInvalidWhence, err)
+ assert.Equal(t, 0, int(pos))
+
+ // check read
+ dst := make([]byte, 5)
+ n, err := h.Read(dst)
+ assert.Nil(t, err)
+ assert.Equal(t, 5, n)
+ assert.Equal(t, expectedRead[:5], dst)
+
+ // Test io.SeekCurrent
+ pos, err = h.Seek(-3, io.SeekCurrent)
+ assert.Nil(t, err)
+ assert.Equal(t, 2, int(pos))
+
+ // check read
+ setWantStart(src, 2)
+ n, err = h.Read(dst)
+ assert.Nil(t, err)
+ assert.Equal(t, 5, n)
+ assert.Equal(t, expectedRead[2:7], dst)
+
+ pos, err = h.Seek(-2, io.SeekCurrent)
+ assert.Nil(t, err)
+ assert.Equal(t, 5, int(pos))
+
+ // Test io.SeekEnd
+ pos, err = h.Seek(-3, io.SeekEnd)
+ if !unknownSize {
+ assert.Nil(t, err)
+ assert.Equal(t, len(expectedRead)-3, int(pos))
+ } else {
+ assert.Equal(t, errBadEndSeek, err)
+ assert.Equal(t, 0, int(pos))
+
+ // sync
+ pos, err = h.Seek(1, io.SeekCurrent)
+ assert.Nil(t, err)
+ assert.Equal(t, 6, int(pos))
+ }
+
+ // check read
+ dst = make([]byte, 3)
+ setWantStart(src, int64(len(expectedRead)-3))
+ n, err = h.Read(dst)
+ assert.Nil(t, err)
+ assert.Equal(t, 3, n)
+ assert.Equal(t, expectedRead[len(expectedRead)-3:], dst)
+
+ // check close
+ assert.NoError(t, h.Close())
+ _, err = h.Seek(0, io.SeekCurrent)
+ assert.Equal(t, errFileClosed, err)
+ })
+
+ t.Run("AccountRead", func(t *testing.T) {
+ h, _, err := testReOpen(nil, 10)
+ assert.NoError(t, err)
+
+ var total int
+ h.SetAccounting(func(n int) error {
+ total += n
+ return nil
+ })
+
+ dst := make([]byte, 3)
+ n, err := h.Read(dst)
+ assert.Equal(t, 3, n)
+ assert.NoError(t, err)
+ assert.Equal(t, 3, total)
+ })
+
+ t.Run("AccountReadDelay", func(t *testing.T) {
+ h, _, err := testReOpen(nil, 10)
+ assert.NoError(t, err)
+
+ var total int
+ h.SetAccounting(func(n int) error {
+ total += n
+ return nil
+ })
+
+ rewind := func() {
+ _, err := h.Seek(0, io.SeekStart)
+ require.NoError(t, err)
+ }
+
+ h.DelayAccounting(3)
+
+ dst := make([]byte, 16)
+
+ n, err := h.Read(dst)
+ assert.Equal(t, len(expectedRead), n)
+ assert.Equal(t, io.EOF, err)
+ assert.Equal(t, 0, total)
+ rewind()
+
+ n, err = h.Read(dst)
+ assert.Equal(t, len(expectedRead), n)
+ assert.Equal(t, io.EOF, err)
+ assert.Equal(t, 0, total)
+ rewind()
+
+ n, err = h.Read(dst)
+ assert.Equal(t, len(expectedRead), n)
+ assert.Equal(t, io.EOF, err)
+ assert.Equal(t, len(expectedRead), total)
+ rewind()
+
+ n, err = h.Read(dst)
+ assert.Equal(t, len(expectedRead), n)
+ assert.Equal(t, io.EOF, err)
+ assert.Equal(t, 2*len(expectedRead), total)
+ rewind()
+ })
+
+ t.Run("AccountReadError", func(t *testing.T) {
+ // Test accounting errors
+ h, _, err := testReOpen(nil, 10)
+ assert.NoError(t, err)
+
+ h.SetAccounting(func(n int) error {
+ return errorTestError
+ })
+
+ dst := make([]byte, 3)
+ n, err := h.Read(dst)
+ assert.Equal(t, 3, n)
+ assert.Equal(t, errorTestError, err)
+ })
+ })
+ }
+}