github.com/Files-com/files-sdk-go/v3@v3.1.81/file/downloadparts.go

package file

import (
	"context"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	files_sdk "github.com/Files-com/files-sdk-go/v3"
	"github.com/Files-com/files-sdk-go/v3/file/manager"
	"github.com/Files-com/files-sdk-go/v3/lib"
	"github.com/panjf2000/ants/v2"
	"github.com/samber/lo"
)

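// Release the ants default goroutine pool at package load; each download
// creates and manages its own pool in Run.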
func init() {
	ants.Release()
}

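// DownloadPartChunkSize is the 5 MiB chunk size used for the single-stream
// cutoff and for extra parts requested when the reported size is untrusted;
// DownloadPartLimit caps the per-file download worker pool.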
const (
	DownloadPartChunkSize = int64(1024 * 1024 * 5)
	DownloadPartLimit     = 15
)

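// DownloadParts coordinates a ranged, multi-part download of a single file,
// fanning part requests out over a worker pool and writing each part at its
// offset through a lib.WriterAndAt.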
type DownloadParts struct {
	globalWait manager.ConcurrencyManager
	context.CancelFunc
	context.Context
	queueCancel  context.CancelFunc
	queueContext context.Context
	fs.File
	fs.FileInfo
	lib.WriterAndAt
	totalWritten  int64
	parts         []*Part
	queue         chan *Part
	finishedParts chan *Part
	CloseError    error
	files_sdk.Config
	fileManager *ants.Pool
	*sync.RWMutex
	queueLock      *sync.Mutex
	partsCompleted uint32
	path           string
}

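// Init wires up the source file, its FileInfo, the shared concurrency manager,
// the destination writer, and the SDK config, and returns d for chaining.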
func (d *DownloadParts) Init(file fs.File, info fs.FileInfo, globalWait manager.ConcurrencyManager, writer lib.WriterAndAt, config files_sdk.Config) *DownloadParts {
	d.File = file
	d.FileInfo = info
	d.path = info.Name()
	d.globalWait = globalWait
	d.WriterAndAt = writer
	d.Config = config
	d.RWMutex = &sync.RWMutex{}
	d.queueLock = &sync.Mutex{}
	return d
}

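// Run performs the download. Small files (or serial/single-stream runs) are
// copied in one request; otherwise the file is split into parts that are
// queued, downloaded concurrently, and awaited. The destination writer is
// closed on return and any close error is stored in CloseError.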
func (d *DownloadParts) Run(ctx context.Context) error {
	d.Context, d.CancelFunc = context.WithCancel(ctx)
	d.queueContext, d.queueCancel = context.WithCancel(d.Context)
	defer func() {
		d.Config.LogPath(
			d.path,
			map[string]interface{}{
				"message":  "Finished canceling context and closing file",
				"realSize": atomic.LoadInt64(&d.totalWritten),
			},
		)
		d.CancelFunc()
		d.CloseError = d.WriterAndAt.Close()
		d.fileManager.Release()
	}()
	var err error
	d.fileManager, err = ants.NewPool(lo.Min([]int{DownloadPartLimit, d.globalWait.Max()}))
	if err != nil {
		return err
	}
	if d.downloadFileCutOff() {
		return d.downloadFile()
	}
	d.buildParts()
	d.listenOnQueue()
	d.addPartsToQueue()
	return d.waitForParts()
}

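// downloadFileCutOff reports whether the file should be downloaded as a single
// stream instead of being split into parts.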
func (d *DownloadParts) downloadFileCutOff() bool {
	// Don't break up the file when parts would run serially anyway.
	if d.fileManager.Cap() == 1 || d.globalWait.DownloadFilesAsSingleStream {
		return true
	}

	return d.FileInfo.Size() <= DownloadPartChunkSize*2
}

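// FinalSize returns the number of bytes written to the destination so far.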
func (d *DownloadParts) FinalSize() int64 {
	return atomic.LoadInt64(&d.totalWritten)
}

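// waitForParts blocks until every planned part has finished, logging progress
// for each one, then closes the queue and reconciles the final size.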
func (d *DownloadParts) waitForParts() error {
	var err error
	for i := range d.parts {
		part := <-d.finishedParts
		if part.Err() != nil && !errors.Is(part.Err(), context.Canceled) {
			err = part.Err()
		}
		atomic.AddUint32(&d.partsCompleted, 1)
		d.Config.LogPath(
			d.path,
			map[string]interface{}{
				"RunningParts":  d.fileManager.Running(),
				"limit":         d.fileManager.Cap(),
				"parts":         len(d.parts),
				"Written":       atomic.LoadInt64(&d.totalWritten),
				"PartFinished":  part.number,
				"partBytes":     part.bytes,
				"PartsFinished": i + 1,
				"error":         part.Err(),
			},
		)
	}
	close(d.queue)
	if err != nil {
		return err
	}
	return d.realSizeOverLap()
}

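// realSizeOverLap reconciles the bytes written with the server-reported size.
// When the reported size is untrusted it keeps requesting extra parts past the
// last offset until the server signals the true end of the file; otherwise it
// errors if the sizes do not match.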
func (d *DownloadParts) realSizeOverLap() error {
	lastPart := d.parts[len(d.parts)-1]
	d.Config.LogPath(
		d.path,
		map[string]interface{}{
			"message":  "starting realSizeOverLap",
			"size":     d.FileInfo.Size(),
			"realSize": atomic.LoadInt64(&d.totalWritten),
		},
	)
	defer func() {
		d.Config.LogPath(
			d.path,
			map[string]interface{}{
				"message":  "finishing realSizeOverLap",
				"size":     d.FileInfo.Size(),
				"realSize": atomic.LoadInt64(&d.totalWritten),
			},
		)
	}()
	for {
		if d.FileInfo.(UntrustedSize).UntrustedSize() && d.queueContext.Err() == nil && lastPart.bytes == lastPart.len {
			d.queueLock.Lock()
			d.queue = make(chan *Part, 1)
			d.queueLock.Unlock()
			d.finishedParts = make(chan *Part, 1)
			nextPart := &Part{number: lastPart.number + 1, OffSet: OffSet{off: lastPart.off + lastPart.bytes, len: DownloadPartChunkSize}}
			d.Config.LogPath(d.path, map[string]interface{}{"message": "Next Part for size guess", "part": nextPart.number})
			d.parts = append(d.parts, nextPart)

			go d.processPart(nextPart.Start(d.Context), true)
			select {
			case lastPart = <-d.finishedParts:
				if lastPart.error != nil {
					if lastPart.error == io.EOF || errors.Is(lastPart.error, UntrustedSizeRangeRequestSizeSentReceived) {
						return nil
					}
					return lastPart.error
				}
			case lastPart = <-d.queue:
				if lastPart.error != nil {
					if lastPart.error == io.EOF {
						return nil
					}
					return lastPart.error
				}
			}
		} else {
			if d.FileInfo.Size() != atomic.LoadInt64(&d.totalWritten) && !d.FileInfo.(UntrustedSize).UntrustedSize() {
				return fmt.Errorf("server reported size does not match downloaded file. - expected: %v, actual: %v", d.FileInfo.Size(), atomic.LoadInt64(&d.totalWritten))
			}
			return nil
		}
	}
}

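// addPartsToQueue enqueues every planned part for download.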
func (d *DownloadParts) addPartsToQueue() {
	for _, part := range d.parts {
		d.queue <- part
	}
}

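// listenOnQueue consumes parts from the queue in a background goroutine,
// submitting each one to the worker pool and marking parts finished once they
// are canceled or have exhausted their retries.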
func (d *DownloadParts) listenOnQueue() {
	go func() {
		d.queueLock.Lock()
		defer d.queueLock.Unlock()
		for {
			select {
			case part := <-d.queue:
				if part == nil {
					return
				}
				if d.queueContext.Err() != nil {
					d.finishedParts <- part
					continue
				}
				if part.processing {
					panic(part)
				}
				if len(part.requests) > 3 {
					d.Config.LogPath(d.path, map[string]interface{}{"message": "Maxed out retries", "part": part.number})
					d.finishedParts <- part
				} else {
					if part.Context.Err() != nil {
						d.finishedParts <- part.Done()
						continue
					}
					part.Clear()
					d.globalWait.Wait()
					d.fileManager.Submit(func() {
						d.stateLog()
						d.processPart(part.Start(), false)
						d.globalWait.Done()
					})

					d.slowDownUntilFirstPart(part)
				}
			}
		}
	}()
}

func (d *DownloadParts) slowDownUntilFirstPart(part *Part) {
	// One request needs to come back with the MaxConnections header so the
	// worker pool can be tuned to that value; slow the later parts down to give
	// the first part time to finish.
	if atomic.LoadUint32(&d.partsCompleted) != 0 || d.parts[0].Err() != nil {
		return
	}
	startTime := time.Now()
	timeout := startTime.Add(time.Duration((part.number)*250) * time.Millisecond)
	ctx, cancel := context.WithDeadline(d.Context, timeout)
	defer cancel()

	for {
		select {
		case <-ctx.Done():
			d.Config.LogPath(d.path, map[string]interface{}{"message": fmt.Sprintf("Part1 to Finish: stopped waiting %v after part %v", time.Since(startTime).Truncate(time.Microsecond), part.number)})
			return
		default:
			if atomic.LoadUint32(&d.partsCompleted) != 0 || d.parts[0].Err() != nil {
				d.Config.LogPath(d.path, map[string]interface{}{"message": fmt.Sprintf("Part1 to Finish: finished after waiting %v after part %v", time.Since(startTime).Truncate(time.Microsecond), part.number)})
				cancel()
				return
			}
		}
	}
}

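// stateLog logs the current download state merged with any extra fields.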
func (d *DownloadParts) stateLog(extraState ...map[string]interface{}) {
	d.Config.LogPath(
		d.path,
		lo.Assign[string, interface{}](append(extraState, d.state())...),
	)
}

func (d *DownloadParts) state() map[string]interface{} {
	return map[string]interface{}{
		"RunningParts": d.fileManager.Running(),
		"limit":        d.fileManager.Cap(),
		"parts":        len(d.parts),
		"written":      atomic.LoadInt64(&d.totalWritten),
		"completed":    atomic.LoadUint32(&d.partsCompleted),
	}
}

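// buildParts splits the reported file size into offset ranges and allocates
// the queue and finished-part channels to match.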
func (d *DownloadParts) buildParts() {
	size := d.FileInfo.Size()
	iter := (ByteOffset{PartSizes: lib.PartSizes}).BySize(&size)

	for {
		offset, next, i := iter()
		d.parts = append(d.parts, (&Part{OffSet: offset, number: i + 1}).WithContext(d.Context))
		if next == nil {
			break
		}
		iter = next
	}

	d.finishedParts = make(chan *Part, len(d.parts))
	d.queue = make(chan *Part, len(d.parts))
	d.stateLog()
}

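// processPart downloads a single part using the file's ranged reader.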
func (d *DownloadParts) processPart(part *Part, UnexpectedEOF bool) {
	d.processRanger(part, d.File.(ReaderRange), UnexpectedEOF)
}

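// processRanger issues the ranged request for a part, tunes the worker pool to
// the server's MaxConnections, copies the body into place at the part's
// offset, and either requeues the part on a retryable error or marks it done.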
func (d *DownloadParts) processRanger(part *Part, ranger ReaderRange, UnexpectedEOF bool) {
	withContext, ok := ranger.(lib.FileWithContext)
	if ok {
		partCtx, partCancel := context.WithCancel(part.Context)
		defer partCancel()
		ranger = withContext.WithContext(partCtx).(ReaderRange)
	}
	r, err := ranger.ReaderRange(part.off, part.len+part.off-1)
	if d.requeueOnError(part, err, UnexpectedEOF) {
		return
	}
	if f, ok := ranger.(*File); ok {
		if f.MaxConnections != 0 && d.fileManager.Cap() > f.MaxConnections {
			d.fileManager.Tune(f.MaxConnections)
			d.stateLog(map[string]interface{}{"message": "tuning pool", "cap": d.fileManager.Cap()})
		}
	}
	info, _ := ranger.Stat()
	sizeTrustInfo, ok := info.(UntrustedSize)
	if ok && sizeTrustInfo.SizeTrust() != NullSizeTrust {
		d.RWMutex.Lock()
		d.FileInfo = sizeTrustInfo
		d.RWMutex.Unlock()
	}

	wn, err := lib.CopyAt(d.WriterAndAt, part.off, r)
	part.bytes = wn

	part.SetError(r.Close())
	if sizeTrustInfo.UntrustedSize() && part.Err() != nil {
		d.verifySizeAndUpdateParts(part)
	}

	if d.requeueOnError(part, err, UnexpectedEOF) {
		return
	}

	atomic.AddInt64(&d.totalWritten, wn)
	d.finishedParts <- part.Done()
}

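// verifySizeAndUpdateParts handles an untrusted-size download that received
// fewer bytes than requested: it treats the short part as the real end of the
// file, cancels all later parts, and clears the part's error.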
func (d *DownloadParts) verifySizeAndUpdateParts(part *Part) {
	if errors.Is(part.Err(), UntrustedSizeRangeRequestSizeSentLessThanExpected) {
		d.Config.LogPath(
			d.path,
			map[string]interface{}{"error": part.Err(), "part": part.number},
		)
		d.queueCancel()
		// cancel all parts after this one
		for _, p := range d.parts[part.number:] {
			d.Config.LogPath(
				d.path,
				map[string]interface{}{"message": "canceling invalid part", "part": p.number},
			)
			p.CancelFunc()
		}
		part.SetError(nil)
	}
}

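// requeueOnError reports whether the part was put back on the queue. EOF,
// "stream error", and (when UnexpectedEOF is set) io.ErrUnexpectedEOF are not
// retried here; any other error rolls back the part's progress and requeues it.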
func (d *DownloadParts) requeueOnError(part *Part, err error, UnexpectedEOF bool) bool {
	for _, err := range []error{err, part.Err()} {
		if err != nil && !errors.Is(err, io.EOF) {
			if strings.Contains(err.Error(), "stream error") {
				return false
			}
			if UnexpectedEOF && errors.Is(err, io.ErrUnexpectedEOF) {
				return false
			}
			if part.error == nil {
				part.SetError(err)
			}
			d.Config.LogPath(
				d.path,
				map[string]interface{}{"message": "requeuing", "error": part.Err(), "part": part.number},
			)
			progressWriter, ok := d.WriterAndAt.(lib.ProgressWriter)
			if ok {
				progressWriter.ProgressWatcher(-part.bytes)
			}
			d.queue <- part.Done() // e.g. a timeout; try the part again.
			return true
		}
	}

	return false
}

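// downloadFile streams the entire file in a single request and verifies the
// number of bytes written against the server-reported size.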
func (d *DownloadParts) downloadFile() error {
	withContext, ok := d.File.(lib.FileWithContext)
	if ok {
		d.File = withContext.WithContext(d.Context)
	}
	n, err := io.Copy(d.WriterAndAt, d.File)
	atomic.AddInt64(&d.totalWritten, n)
	if err != nil {
		return err
	}
	err = d.File.Close()
	if err != nil {
		return err
	}

	info, _ := d.File.Stat()
	sizeTrustInfo, ok := info.(UntrustedSize)
	if ok && sizeTrustInfo.SizeTrust() != NullSizeTrust {
		d.FileInfo = sizeTrustInfo
	}

	if d.FileInfo.Size() != atomic.LoadInt64(&d.totalWritten) {
		return fmt.Errorf("server reported size does not match downloaded file. - expected: %v, actual: %v", d.FileInfo.Size(), atomic.LoadInt64(&d.totalWritten))
	}
	return nil
}