github.com/Files-com/files-sdk-go/v2@v2.1.2/file/downloadparts.go

package file

import (
	"context"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/Files-com/files-sdk-go/v2/file/manager"
	"github.com/Files-com/files-sdk-go/v2/lib"
	"github.com/panjf2000/ants/v2"

	"github.com/samber/lo"

	files_sdk "github.com/Files-com/files-sdk-go/v2"
)

const (
	DownloadPartChunkSize = int64(1024 * 1024 * 5)
	DownloadPartLimit     = 15
)
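
// DownloadParts coordinates the concurrent, ranged download of a single file:
// the file is split into fixed-size parts, each part is fetched with a range
// request on a worker from an ants pool, and the bytes are written back at the
// correct offsets through a lib.WriterAndAt. Files below the size cutoff, or
// downloads forced to a single stream, are copied in one request instead.
//
// A rough usage sketch (the concrete file, info, writer, and concurrency
// manager come from elsewhere in the SDK; the names here are placeholders):
//
//	parts := (&DownloadParts{}).Init(remoteFile, remoteInfo, concurrencyManager, destWriter, config)
//	if err := parts.Run(ctx); err != nil {
//		// handle the download error
//	}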
type DownloadParts struct {
	globalWait manager.ConcurrencyManager
	context.CancelFunc
	context.Context
	queueCancel  context.CancelFunc
	queueContext context.Context
	fs.File
	fs.FileInfo
	lib.WriterAndAt
	totalWritten  int64
	parts         []*Part
	queue         chan *Part
	finishedParts chan *Part
	CloseError    error
	files_sdk.Config
	fileManager *ants.Pool
	*sync.RWMutex
	queueLock      *sync.Mutex
	partsCompleted uint32
	path           string
}
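
// Init wires up the source file, its FileInfo, the shared concurrency manager,
// the destination writer, and the SDK config, and returns the receiver so the
// call can be chained with Run.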
func (d *DownloadParts) Init(file fs.File, info fs.FileInfo, globalWait manager.ConcurrencyManager, writer lib.WriterAndAt, config files_sdk.Config) *DownloadParts {
	d.File = file
	d.FileInfo = info
	d.path = info.Name()
	d.globalWait = globalWait
	d.WriterAndAt = writer
	d.Config = config
	d.RWMutex = &sync.RWMutex{}
	d.queueLock = &sync.Mutex{}
	return d
}
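
// Run drives the download. Small files and single-stream downloads are copied
// in one request via downloadFile; otherwise the file is split into parts that
// are queued, fetched concurrently, and written back at their offsets. The
// destination writer is always closed on return and any close error is
// recorded in CloseError.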
func (d *DownloadParts) Run(ctx context.Context) error {
	d.Context, d.CancelFunc = context.WithCancel(ctx)
	d.queueContext, d.queueCancel = context.WithCancel(d.Context)
	defer func() {
		d.Config.LogPath(
			d.path,
			map[string]interface{}{
				"message":  "Finished canceling context and closing file",
				"realSize": atomic.LoadInt64(&d.totalWritten),
			},
		)
		d.CancelFunc()
		d.CloseError = d.WriterAndAt.Close()
		d.fileManager.Release()
	}()
	var err error
	d.fileManager, err = ants.NewPool(lo.Min([]int{DownloadPartLimit, d.globalWait.Max()}))
	if err != nil {
		return err
	}
	if d.downloadFileCutOff() {
		return d.downloadFile()
	} else {
		d.buildParts()
		d.listenOnQueue()
		d.addPartsToQueue()
		return d.waitForParts()
	}
}
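
// downloadFileCutOff reports whether the file should be fetched as a single
// stream rather than in parts: concurrency is effectively serial, the caller
// asked for single-stream downloads, or the file is no larger than two chunks.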
func (d *DownloadParts) downloadFileCutOff() bool {
	// Don't break up the file if parts would run serially anyway.
	if d.fileManager.Cap() == 1 || d.globalWait.DownloadFilesAsSingleStream {
		return true
	}

	return d.FileInfo.Size() <= DownloadPartChunkSize*2
}
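
// FinalSize returns the number of bytes written so far. For untrusted sizes
// this is the authoritative size, and it can differ from the size the server
// originally reported.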
func (d *DownloadParts) FinalSize() int64 {
	return atomic.LoadInt64(&d.totalWritten)
}
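
// waitForParts blocks until every part has been reported on finishedParts,
// logging per-part progress, then closes the queue and verifies the final size.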
func (d *DownloadParts) waitForParts() error {
	var err error
	for i := range d.parts {
		part := <-d.finishedParts
		if part.Err() != nil && !errors.Is(part.Err(), context.Canceled) {
			err = part.Err()
		}
		atomic.AddUint32(&d.partsCompleted, 1)
		d.Config.LogPath(
			d.path,
			map[string]interface{}{
				"RunningParts":  d.fileManager.Running(),
				"limit":         d.fileManager.Cap(),
				"parts":         len(d.parts),
				"Written":       atomic.LoadInt64(&d.totalWritten),
				"PartFinished":  part.number,
				"partBytes":     part.bytes,
				"PartsFinished": i + 1,
				"error":         part.Err(),
			},
		)
	}
	close(d.queue)
	if err != nil {
		return err
	}
	return d.realSizeOverLap()
}
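
// realSizeOverLap reconciles the bytes written with the reported size. When the
// size is untrusted and the last part came back completely full, the real end
// of the file may not have been reached yet, so further parts are fetched one
// at a time until a short read or EOF marks the true end. For trusted sizes it
// only verifies that the bytes written match the reported size.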
func (d *DownloadParts) realSizeOverLap() error {
	lastPart := d.parts[len(d.parts)-1]
	d.Config.LogPath(
		d.path,
		map[string]interface{}{
			"message":  "starting realSizeOverLap",
			"size":     d.FileInfo.Size(),
			"realSize": atomic.LoadInt64(&d.totalWritten),
		},
	)
	defer func() {
		d.Config.LogPath(
			d.path,
			map[string]interface{}{
				"message":  "finishing realSizeOverLap",
				"size":     d.FileInfo.Size(),
				"realSize": atomic.LoadInt64(&d.totalWritten),
			},
		)
	}()
	for {
		if d.FileInfo.(UntrustedSize).UntrustedSize() && d.queueContext.Err() == nil && lastPart.bytes == lastPart.len {
			d.queueLock.Lock()
			d.queue = make(chan *Part, 1)
			d.queueLock.Unlock()
			d.finishedParts = make(chan *Part, 1)
			nextPart := &Part{number: lastPart.number + 1, OffSet: OffSet{off: lastPart.off + lastPart.bytes, len: DownloadPartChunkSize}}
			d.Config.LogPath(d.path, map[string]interface{}{"message": "Next Part for size guess", "part": nextPart.number})
			d.parts = append(d.parts, nextPart)

			go d.processPart(nextPart.Start(d.Context), true)
			select {
			case lastPart = <-d.finishedParts:
				if lastPart.error != nil {
					if lastPart.error == io.EOF || errors.Is(lastPart.error, UntrustedSizeRangeRequestSizeSentReceived) {
						return nil
					}
					return lastPart.error
				}
			case lastPart = <-d.queue:
				if lastPart.error != nil {
					if lastPart.error == io.EOF {
						return nil
					}
					return lastPart.error
				}
			}
		} else {
			if d.FileInfo.Size() != atomic.LoadInt64(&d.totalWritten) && !d.FileInfo.(UntrustedSize).UntrustedSize() {
				return fmt.Errorf("server reported size does not match downloaded file: expected %v, actual %v", d.FileInfo.Size(), atomic.LoadInt64(&d.totalWritten))
			}
			return nil
		}
	}
}

func (d *DownloadParts) addPartsToQueue() {
	for _, part := range d.parts {
		d.queue <- part
	}
}
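
// listenOnQueue consumes parts from the queue on a background goroutine and
// submits each one to the worker pool, requeueing failed parts until they run
// out of retries. A nil part, received once the queue is closed, stops the loop.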
func (d *DownloadParts) listenOnQueue() {
	go func() {
		d.queueLock.Lock()
		defer d.queueLock.Unlock()
		for {
			select {
			case part := <-d.queue:
				if part == nil {
					return
				}
				if d.queueContext.Err() != nil {
					d.finishedParts <- part
					continue
				}
				if part.processing {
					panic(part)
				}
				if len(part.requests) > 3 {
					d.Config.LogPath(d.path, map[string]interface{}{"message": "Maxed out retries", "part": part.number})
					d.finishedParts <- part
				} else {
					if part.Context.Err() != nil {
						d.finishedParts <- part.Done()
						continue
					}
					part.Clear()
					d.globalWait.Wait()
					d.fileManager.Submit(func() {
						d.stateLog()
						d.processPart(part.Start(), false)
						d.globalWait.Done()
					})

					d.slowDownTellFirstPart(part)
				}
			}
		}
	}()
}
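
// slowDownTellFirstPart throttles parts submitted before the first part has
// completed, giving the first response time to report MaxConnections so the
// worker pool can be tuned before too many connections are opened.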
func (d *DownloadParts) slowDownTellFirstPart(part *Part) {
	// One request needs to return the MaxConnections header; once it has, the
	// pool can be tuned to that value. Slow down to give that first request
	// time to finish.
	if atomic.LoadUint32(&d.partsCompleted) != 0 || d.parts[0].Err() != nil {
		return
	}
	startTime := time.Now()
	timeout := startTime.Add(time.Duration((part.number)*250) * time.Millisecond)
	ctx, cancel := context.WithDeadline(d.Context, timeout)
	defer cancel()

	for {
		select {
		case <-ctx.Done():
			d.Config.LogPath(d.path, map[string]interface{}{"message": fmt.Sprintf("Part1 to Finish: stopped waiting %v after part %v", time.Since(startTime).Truncate(time.Microsecond), part.number)})
			return
		default:
			if atomic.LoadUint32(&d.partsCompleted) != 0 || d.parts[0].Err() != nil {
				d.Config.LogPath(d.path, map[string]interface{}{"message": fmt.Sprintf("Part1 to Finish: finished after waiting %v after part %v", time.Since(startTime).Truncate(time.Microsecond), part.number)})
				cancel()
				return
			}
		}
	}
}

func (d *DownloadParts) stateLog(extraState ...map[string]interface{}) {
	d.Config.LogPath(
		d.path,
		lo.Assign[string, interface{}](append(extraState, d.state())...),
	)
}

func (d *DownloadParts) state() map[string]interface{} {
	return map[string]interface{}{
		"RunningParts": d.fileManager.Running(),
		"limit":        d.fileManager.Cap(),
		"parts":        len(d.parts),
		"written":      atomic.LoadInt64(&d.totalWritten),
		"completed":    atomic.LoadUint32(&d.partsCompleted),
	}
}
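
// buildParts slices the reported file size into offset ranges using
// lib.PartSizes and sizes the queue and finishedParts channels to match.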
func (d *DownloadParts) buildParts() {
	size := d.FileInfo.Size()
	iter := (ByteOffset{PartSizes: lib.PartSizes}).BySize(&size)

	for {
		offset, next, i := iter()
		d.parts = append(d.parts, (&Part{OffSet: offset, number: i + 1}).WithContext(d.Context))
		if next == nil {
			break
		}
		iter = next
	}

	d.finishedParts = make(chan *Part, len(d.parts))
	d.queue = make(chan *Part, len(d.parts))
	d.stateLog()
}

func (d *DownloadParts) processPart(part *Part, UnexpectedEOF bool) {
	d.processRanger(part, d.File.(ReaderRange), UnexpectedEOF)
}
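
// processRanger downloads a single part: it issues a range request for the
// part's byte window, tunes the pool down to the server's MaxConnections when
// one is reported, adopts the server's size-trust information, copies the body
// to the destination at the part's offset, and either marks the part done or
// hands it back to requeueOnError for a retry decision.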
func (d *DownloadParts) processRanger(part *Part, ranger ReaderRange, UnexpectedEOF bool) {
	withContext, ok := ranger.(lib.FileWithContext)
	if ok {
		partCtx, partCancel := context.WithCancel(part.Context)
		defer partCancel()
		ranger = withContext.WithContext(partCtx).(ReaderRange)
	}
	r, err := ranger.ReaderRange(part.off, part.len+part.off-1)
	if d.requeueOnError(part, err, UnexpectedEOF) {
		return
	}
	if f, ok := ranger.(*File); ok {
		if f.MaxConnections != 0 && d.fileManager.Cap() > f.MaxConnections {
			d.fileManager.Tune(f.MaxConnections)
			d.stateLog(map[string]interface{}{"message": "tuning pool", "cap": d.fileManager.Cap()})
		}
	}
	info, _ := ranger.Stat()
	sizeTrustInfo, ok := info.(UntrustedSize)
	if ok && sizeTrustInfo.SizeTrust() != NullSizeTrust {
		d.RWMutex.Lock()
		d.FileInfo = sizeTrustInfo
		d.RWMutex.Unlock()
	}

	wn, err := lib.CopyAt(d.WriterAndAt, part.off, r)
	part.bytes = wn

	part.SetError(r.Close())
	// Only consult size trust when the info actually implements UntrustedSize.
	if ok && sizeTrustInfo.UntrustedSize() && part.Err() != nil {
		d.verifySizeAndUpdateParts(part)
	}

	if d.requeueOnError(part, err, UnexpectedEOF) {
		return
	}

	atomic.AddInt64(&d.totalWritten, wn)
	d.finishedParts <- part.Done()
}
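
// verifySizeAndUpdateParts handles a short range response when the size was
// untrusted: the true end of the file has been reached, so the remaining
// higher-numbered parts are canceled and the short-read error is cleared.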
func (d *DownloadParts) verifySizeAndUpdateParts(part *Part) {
	if errors.Is(part.Err(), UntrustedSizeRangeRequestSizeSentLessThanExpected) {
		d.Config.LogPath(
			d.path,
			map[string]interface{}{"error": part.Err(), "part": part.number},
		)
		d.queueCancel()
		// cancel all parts after this one
		for _, p := range d.parts[part.number:] {
			d.Config.LogPath(
				d.path,
				map[string]interface{}{"message": "canceling invalid part", "part": p.number},
			)
			p.CancelFunc()
		}
		part.SetError(nil)
	}
}
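
// requeueOnError decides whether a part should be retried. Stream errors and,
// when UnexpectedEOF is allowed, io.ErrUnexpectedEOF are left for the caller to
// handle; any other non-EOF error rolls back the part's progress reporting and
// puts the part back on the queue. It returns true when the part was requeued.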
func (d *DownloadParts) requeueOnError(part *Part, err error, UnexpectedEOF bool) bool {
	for _, err := range []error{err, part.Err()} {
		if err != nil && !errors.Is(err, io.EOF) {
			if strings.Contains(err.Error(), "stream error") {
				return false
			}
			if UnexpectedEOF && errors.Is(err, io.ErrUnexpectedEOF) {
				return false
			}
			if part.error == nil {
				part.SetError(err)
			}
			d.Config.LogPath(
				d.path,
				map[string]interface{}{"message": "requeuing", "error": part.Err(), "part": part.number},
			)
			progressWriter, ok := d.WriterAndAt.(lib.ProgressWriter)
			if ok {
				progressWriter.ProgressWatcher(-part.bytes)
			}
			d.queue <- part.Done() // a timeout or other retryable error; try the part again.
			return true
		}
	}

	return false
}
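
// downloadFile streams the whole file in a single request and then verifies
// that the number of bytes written matches the (possibly updated) size the
// server reported.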
func (d *DownloadParts) downloadFile() error {
	withContext, ok := d.File.(lib.FileWithContext)
	if ok {
		d.File = withContext.WithContext(d.Context)
	}
	n, err := io.Copy(d.WriterAndAt, d.File)
	atomic.AddInt64(&d.totalWritten, n)
	if err != nil {
		return err
	}
	err = d.File.Close()
	if err != nil {
		return err
	}

	info, _ := d.File.Stat()
	sizeTrustInfo, ok := info.(UntrustedSize)
	if ok && sizeTrustInfo.SizeTrust() != NullSizeTrust {
		d.FileInfo = sizeTrustInfo
	}

	if d.FileInfo.Size() != atomic.LoadInt64(&d.totalWritten) {
		return fmt.Errorf("server reported size does not match downloaded file: expected %v, actual %v", d.FileInfo.Size(), atomic.LoadInt64(&d.totalWritten))
	}
	return nil
}