diff options
Diffstat (limited to 'fs/operations/reopen_test.go')
| -rw-r--r-- | fs/operations/reopen_test.go | 429 |
1 files changed, 429 insertions, 0 deletions
diff --git a/fs/operations/reopen_test.go b/fs/operations/reopen_test.go new file mode 100644 index 0000000..19f6ca3 --- /dev/null +++ b/fs/operations/reopen_test.go @@ -0,0 +1,429 @@ +package operations + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/rclone/rclone/fs" + "github.com/rclone/rclone/fs/hash" + "github.com/rclone/rclone/fstest/mockobject" + "github.com/rclone/rclone/lib/pool" + "github.com/rclone/rclone/lib/readers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// check interfaces +var ( + _ io.ReadSeekCloser = (*ReOpen)(nil) + _ io.ReaderAt = (*ReOpen)(nil) + _ pool.DelayAccountinger = (*ReOpen)(nil) +) + +var errorTestError = errors.New("test error") + +// this is a wrapper for a mockobject with a custom Open function +// +// breaks indicate the number of bytes to read before returning an +// error +type reOpenTestObject struct { + fs.Object + t *testing.T + wantStart int64 + breaks []int64 + unknownSize bool +} + +// Open opens the file for read. Call Close() on the returned io.ReadCloser +// +// This will break after reading the number of bytes in breaks +func (o *reOpenTestObject) Open(ctx context.Context, options ...fs.OpenOption) (io.ReadCloser, error) { + // Lots of backends do this - make sure it works as it modifies options + fs.FixRangeOption(options, o.Size()) + gotHash := false + gotRange := false + startPos := int64(0) + for _, option := range options { + switch x := option.(type) { + case *fs.HashesOption: + gotHash = true + case *fs.RangeOption: + gotRange = true + startPos = x.Start + if o.unknownSize { + assert.Equal(o.t, int64(-1), x.End) + } + case *fs.SeekOption: + startPos = x.Offset + } + } + assert.Equal(o.t, o.wantStart, startPos) + // Check if ranging, mustn't have hash if offset != 0 + if gotHash && gotRange { + assert.Equal(o.t, int64(0), startPos) + } + rc, err := o.Object.Open(ctx, options...) + if err != nil { + return nil, err + } + if len(o.breaks) > 0 { + // Pop a breakpoint off + N := o.breaks[0] + o.breaks = o.breaks[1:] + o.wantStart += N + // If 0 then return an error immediately + if N == 0 { + return nil, errorTestError + } + // Read N bytes then an error + r := io.MultiReader(&io.LimitedReader{R: rc, N: N}, readers.ErrorReader{Err: errorTestError}) + // Wrap with Close in a new readCloser + rc = readCloser{Reader: r, Closer: rc} + } + return rc, nil +} + +func TestReOpen(t *testing.T) { + for _, testName := range []string{"Normal", "WithRangeOption", "WithSeekOption", "UnknownSize"} { + t.Run(testName, func(t *testing.T) { + // Contents for the mock object + var ( + reOpenTestcontents = []byte("0123456789") + expectedRead = reOpenTestcontents + rangeOption *fs.RangeOption + seekOption *fs.SeekOption + unknownSize = false + ) + switch testName { + case "Normal": + case "WithRangeOption": + rangeOption = &fs.RangeOption{Start: 1, End: 7} // range is inclusive + expectedRead = reOpenTestcontents[1:8] + case "WithSeekOption": + seekOption = &fs.SeekOption{Offset: 2} + expectedRead = reOpenTestcontents[2:] + case "UnknownSize": + rangeOption = &fs.RangeOption{Start: 1, End: -1} + expectedRead = reOpenTestcontents[1:] + unknownSize = true + default: + panic("bad test name") + } + + // Start the test with the given breaks + testReOpen := func(breaks []int64, maxRetries int) (*ReOpen, *reOpenTestObject, error) { + srcOrig := mockobject.New("potato").WithContent(reOpenTestcontents, mockobject.SeekModeNone) + srcOrig.SetUnknownSize(unknownSize) + src := &reOpenTestObject{ + Object: srcOrig, + t: t, + breaks: breaks, + unknownSize: unknownSize, + } + opts := []fs.OpenOption{} + if rangeOption == nil && seekOption == nil { + opts = append(opts, &fs.HashesOption{Hashes: hash.NewHashSet(hash.MD5)}) + } + if rangeOption != nil { + opts = append(opts, rangeOption) + src.wantStart = rangeOption.Start + } + if seekOption != nil { + opts = append(opts, seekOption) + src.wantStart = seekOption.Offset + } + rc, err := NewReOpen(context.Background(), src, maxRetries, opts...) + return rc, src, err + } + + // Reset the start after a seek, taking into account the offset + setWantStart := func(src *reOpenTestObject, x int64) { + src.wantStart = x + if rangeOption != nil { + src.wantStart += rangeOption.Start + } else if seekOption != nil { + src.wantStart += seekOption.Offset + } + } + + t.Run("Basics", func(t *testing.T) { + // open + h, _, err := testReOpen(nil, 10) + assert.NoError(t, err) + + // Check contents read correctly + got, err := io.ReadAll(h) + assert.NoError(t, err) + assert.Equal(t, expectedRead, got) + + // Check read after end + var buf = make([]byte, 1) + n, err := h.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) + + // Rewind the stream + _, err = h.Seek(0, io.SeekStart) + require.NoError(t, err) + + // Check contents read correctly + got, err = io.ReadAll(h) + assert.NoError(t, err) + assert.Equal(t, expectedRead, got) + + // Check close + assert.NoError(t, h.Close()) + + // Check double close + assert.Equal(t, errFileClosed, h.Close()) + + // Check read after close + n, err = h.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, errFileClosed, err) + }) + + t.Run("ErrorAtStart", func(t *testing.T) { + // open with immediate breaking + h, _, err := testReOpen([]int64{0}, 10) + assert.Equal(t, errorTestError, err) + assert.Nil(t, h) + }) + + t.Run("WithErrors", func(t *testing.T) { + // open with a few break points but less than the max + h, _, err := testReOpen([]int64{2, 1, 3}, 10) + assert.NoError(t, err) + + // check contents + got, err := io.ReadAll(h) + assert.NoError(t, err) + assert.Equal(t, expectedRead, got) + + // check close + assert.NoError(t, h.Close()) + }) + + t.Run("TooManyErrors", func(t *testing.T) { + // open with a few break points but >= the max + h, _, err := testReOpen([]int64{2, 1, 3}, 3) + assert.NoError(t, err) + + // check contents + got, err := io.ReadAll(h) + assert.Equal(t, errorTestError, err) + assert.Equal(t, expectedRead[:6], got) + + // check old error is returned + var buf = make([]byte, 1) + n, err := h.Read(buf) + assert.Equal(t, 0, n) + assert.Equal(t, errTooManyTries, err) + + // Check close + assert.Equal(t, errFileClosed, h.Close()) + }) + + t.Run("ReadAt", func(t *testing.T) { + // open + h, src, err := testReOpen([]int64{2, 1, 3}, 10) + assert.NoError(t, err) + + buf := make([]byte, 5) + + // Read at 0 + n, err := h.ReadAt(buf, 0) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, expectedRead[:n], buf[:n]) + + // Read at 1 + setWantStart(src, 1) + n, err = h.ReadAt(buf[:3], 1) + require.NoError(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, expectedRead[1:n+1], buf[:n]) + + // check position unchanged + pos, err := h.Seek(0, io.SeekCurrent) + require.NoError(t, err) + assert.Equal(t, int64(0), pos) + + // check close + assert.NoError(t, h.Close()) + _, err = h.Seek(0, io.SeekCurrent) + assert.Equal(t, errFileClosed, err) + }) + + t.Run("Seek", func(t *testing.T) { + // open + h, src, err := testReOpen([]int64{2, 1, 3}, 10) + assert.NoError(t, err) + + // Seek to end + pos, err := h.Seek(int64(len(expectedRead)), io.SeekStart) + assert.NoError(t, err) + assert.Equal(t, int64(len(expectedRead)), pos) + + // Seek to start + pos, err = h.Seek(0, io.SeekStart) + assert.NoError(t, err) + assert.Equal(t, int64(0), pos) + + // Should not allow seek past end + pos, err = h.Seek(int64(len(expectedRead))+1, io.SeekCurrent) + if !unknownSize { + assert.Equal(t, errSeekPastEnd, err) + assert.Equal(t, len(expectedRead), int(pos)) + } else { + assert.Equal(t, nil, err) + assert.Equal(t, len(expectedRead)+1, int(pos)) + + // Seek back to start to get tests in sync + pos, err = h.Seek(0, io.SeekStart) + assert.NoError(t, err) + assert.Equal(t, int64(0), pos) + } + + // Should not allow seek to negative position start + pos, err = h.Seek(-1, io.SeekCurrent) + assert.Equal(t, errNegativeSeek, err) + assert.Equal(t, 0, int(pos)) + + // Should not allow seek with invalid whence + pos, err = h.Seek(0, 3) + assert.Equal(t, errInvalidWhence, err) + assert.Equal(t, 0, int(pos)) + + // check read + dst := make([]byte, 5) + n, err := h.Read(dst) + assert.Nil(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, expectedRead[:5], dst) + + // Test io.SeekCurrent + pos, err = h.Seek(-3, io.SeekCurrent) + assert.Nil(t, err) + assert.Equal(t, 2, int(pos)) + + // check read + setWantStart(src, 2) + n, err = h.Read(dst) + assert.Nil(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, expectedRead[2:7], dst) + + pos, err = h.Seek(-2, io.SeekCurrent) + assert.Nil(t, err) + assert.Equal(t, 5, int(pos)) + + // Test io.SeekEnd + pos, err = h.Seek(-3, io.SeekEnd) + if !unknownSize { + assert.Nil(t, err) + assert.Equal(t, len(expectedRead)-3, int(pos)) + } else { + assert.Equal(t, errBadEndSeek, err) + assert.Equal(t, 0, int(pos)) + + // sync + pos, err = h.Seek(1, io.SeekCurrent) + assert.Nil(t, err) + assert.Equal(t, 6, int(pos)) + } + + // check read + dst = make([]byte, 3) + setWantStart(src, int64(len(expectedRead)-3)) + n, err = h.Read(dst) + assert.Nil(t, err) + assert.Equal(t, 3, n) + assert.Equal(t, expectedRead[len(expectedRead)-3:], dst) + + // check close + assert.NoError(t, h.Close()) + _, err = h.Seek(0, io.SeekCurrent) + assert.Equal(t, errFileClosed, err) + }) + + t.Run("AccountRead", func(t *testing.T) { + h, _, err := testReOpen(nil, 10) + assert.NoError(t, err) + + var total int + h.SetAccounting(func(n int) error { + total += n + return nil + }) + + dst := make([]byte, 3) + n, err := h.Read(dst) + assert.Equal(t, 3, n) + assert.NoError(t, err) + assert.Equal(t, 3, total) + }) + + t.Run("AccountReadDelay", func(t *testing.T) { + h, _, err := testReOpen(nil, 10) + assert.NoError(t, err) + + var total int + h.SetAccounting(func(n int) error { + total += n + return nil + }) + + rewind := func() { + _, err := h.Seek(0, io.SeekStart) + require.NoError(t, err) + } + + h.DelayAccounting(3) + + dst := make([]byte, 16) + + n, err := h.Read(dst) + assert.Equal(t, len(expectedRead), n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, total) + rewind() + + n, err = h.Read(dst) + assert.Equal(t, len(expectedRead), n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, total) + rewind() + + n, err = h.Read(dst) + assert.Equal(t, len(expectedRead), n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, len(expectedRead), total) + rewind() + + n, err = h.Read(dst) + assert.Equal(t, len(expectedRead), n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 2*len(expectedRead), total) + rewind() + }) + + t.Run("AccountReadError", func(t *testing.T) { + // Test accounting errors + h, _, err := testReOpen(nil, 10) + assert.NoError(t, err) + + h.SetAccounting(func(n int) error { + return errorTestError + }) + + dst := make([]byte, 3) + n, err := h.Read(dst) + assert.Equal(t, 3, n) + assert.Equal(t, errorTestError, err) + }) + }) + } +} |
