github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ext/dload/task.go (about)

     1  // Package dload implements functionality to download resources into AIS cluster from external source.
     2  /*
     3   * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
     4   */
     5  package dload
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/http"
    13  	"os"
    14  	"time"
    15  
    16  	"github.com/NVIDIA/aistore/cmn"
    17  	"github.com/NVIDIA/aistore/cmn/atomic"
    18  	"github.com/NVIDIA/aistore/cmn/cos"
    19  	"github.com/NVIDIA/aistore/cmn/nlog"
    20  	"github.com/NVIDIA/aistore/core"
    21  	"github.com/NVIDIA/aistore/nl"
    22  	"github.com/NVIDIA/aistore/stats"
    23  )
    24  
    25  const (
    26  	gcsUA = "gcloud-golang-storage/20151204" // from cloud.google.com/go/storage/storage.go (userAgent).
    27  )
    28  
    29  const (
    30  	retryCnt         = 10  // number of retries to external resource
    31  	reqTimeoutFactor = 1.2 // newTimeout = prevTimeout * reqTimeoutFactor
    32  	internalErrorMsg = "internal server error"
    33  )
    34  
// singleTask holds the state of downloading one object on behalf of a download job.
//
//   - xdl is handed to PutObject (as the put's xaction) and credited with the
//     downloaded object via ObjsAdd on success.
//   - job supplies the destination bucket, job ID, optional per-job timeout,
//     progress notification, and throttler.
//   - obj describes what to fetch: the source link, the destination object name,
//     and whether it comes from a remote backend (fromRemote) rather than the link.
//   - started/ended are stamped around the transfer and feed the latency stat.
//   - downloadCtx/cancel are created in init(); getCtx is replaced for every
//     attempt and is the context handed to the throttling reader.
type singleTask struct {
	xdl         *Xact
	job         jobif
	obj         dlObj
	started     atomic.Time
	ended       atomic.Time
	currentSize atomic.Int64       // current file size (updated as the download progresses)
	totalSize   atomic.Int64       // total size (nonzero iff Content-Length header was provided by the source)
	downloadCtx context.Context    // w/ cancel function
	getCtx      context.Context    // w/ timeout and size
	cancel      context.CancelFunc // to cancel in-progress download
}
    47  
    48  // List of HTTP status codes which we shouldn'task retry (just report the job failed).
    49  var terminalStatuses = map[int]struct{}{
    50  	http.StatusNotFound:          {},
    51  	http.StatusPaymentRequired:   {},
    52  	http.StatusUnauthorized:      {},
    53  	http.StatusForbidden:         {},
    54  	http.StatusMethodNotAllowed:  {},
    55  	http.StatusNotAcceptable:     {},
    56  	http.StatusProxyAuthRequired: {},
    57  	http.StatusGone:              {},
    58  }
    59  
    60  ////////////////
    61  // singleTask //
    62  ////////////////
    63  
    64  func (task *singleTask) init() {
    65  	// NOTE: `cancel` is called on abort or when download finishes.
    66  	task.downloadCtx, task.cancel = context.WithCancel(context.Background())
    67  }
    68  
    69  func (task *singleTask) download(lom *core.LOM) {
    70  	err := lom.InitBck(task.job.Bck())
    71  	if err == nil {
    72  		err = lom.Load(true /*cache it*/, false /*locked*/)
    73  	}
    74  	if err != nil && !os.IsNotExist(err) {
    75  		task.markFailed(internalErrorMsg)
    76  		return
    77  	}
    78  
    79  	if cmn.Rom.FastV(4, cos.SmoduleDload) {
    80  		nlog.Infof("Starting download for %v", task)
    81  	}
    82  
    83  	task.started.Store(time.Now())
    84  	lom.SetAtimeUnix(task.started.Load().UnixNano())
    85  	if task.obj.fromRemote {
    86  		err = task.downloadRemote(lom)
    87  	} else {
    88  		err = task.downloadLocal(lom)
    89  	}
    90  	task.ended.Store(time.Now())
    91  
    92  	if err != nil {
    93  		task.markFailed(err.Error())
    94  		return
    95  	}
    96  
    97  	g.store.incFinished(task.jobID())
    98  
    99  	g.tstats.AddMany(
   100  		cos.NamedVal64{Name: stats.DownloadSize, Value: task.currentSize.Load()},
   101  		cos.NamedVal64{Name: stats.DownloadLatency, Value: int64(task.ended.Load().Sub(task.started.Load()))},
   102  	)
   103  	task.xdl.ObjsAdd(1, task.currentSize.Load())
   104  }
   105  
   106  func (task *singleTask) _dlocal(lom *core.LOM, timeout time.Duration) (bool /*err is fatal*/, error) {
   107  	ctx, cancel := context.WithTimeout(task.downloadCtx, timeout)
   108  	defer cancel()
   109  
   110  	task.getCtx = ctx
   111  
   112  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, task.obj.link, http.NoBody)
   113  	if err != nil {
   114  		return true, err
   115  	}
   116  
   117  	// Set "User-Agent" header when doing requests to Google Cloud Storage.
   118  	// This should increase the number of connections to GCS.
   119  	if cos.IsGoogleStorageURL(req.URL) {
   120  		req.Header.Add("User-Agent", gcsUA)
   121  	}
   122  
   123  	resp, err := clientForURL(task.obj.link).Do(req) //nolint:bodyclose // cos.Close
   124  	if err != nil {
   125  		return false, err
   126  	}
   127  
   128  	fatal, err := task._dput(lom, req, resp)
   129  	cos.Close(resp.Body)
   130  	return fatal, err
   131  }
   132  
   133  func (task *singleTask) _dput(lom *core.LOM, req *http.Request, resp *http.Response) (bool /*err is fatal*/, error) {
   134  	if resp.StatusCode >= http.StatusBadRequest {
   135  		if resp.StatusCode == http.StatusNotFound {
   136  			return false, cmn.NewErrHTTP(req, fmt.Errorf("%q does not exist", task.obj.link), http.StatusNotFound)
   137  		}
   138  		return false, cmn.NewErrHTTP(req,
   139  			fmt.Errorf("failed to download %q: status %d", task.obj.link, resp.StatusCode),
   140  			resp.StatusCode)
   141  	}
   142  
   143  	r := task.wrapReader(resp.Body)
   144  	size := attrsFromLink(task.obj.link, resp, lom)
   145  	task.setTotalSize(size)
   146  
   147  	params := core.AllocPutParams()
   148  	{
   149  		params.WorkTag = "dl"
   150  		params.Reader = r
   151  		params.OWT = cmn.OwtPut
   152  		params.Atime = task.started.Load()
   153  		params.Size = size
   154  		params.Xact = task.xdl
   155  	}
   156  	erp := core.T.PutObject(lom, params)
   157  	core.FreePutParams(params)
   158  	if erp != nil {
   159  		return true, erp
   160  	}
   161  	if err := lom.Load(true /*cache it*/, false /*locked*/); err != nil {
   162  		return true, err
   163  	}
   164  	return false, nil
   165  }
   166  
   167  func (task *singleTask) downloadLocal(lom *core.LOM) (err error) {
   168  	var (
   169  		timeout = task.initialTimeout()
   170  		fatal   bool
   171  	)
   172  	for i := range retryCnt {
   173  		fatal, err = task._dlocal(lom, timeout)
   174  		if err == nil || fatal {
   175  			return err
   176  		}
   177  
   178  		// handle more
   179  		if errors.Is(err, context.Canceled) || errors.Is(err, errThrottlerStopped) {
   180  			return err // canceled or stopped, so just return
   181  		}
   182  		if errors.Is(err, context.DeadlineExceeded) {
   183  			nlog.Warningf("%s [retries: %d/%d]: timeout (%v) - increasing and retrying", task, i, retryCnt, timeout)
   184  			timeout = time.Duration(float64(timeout) * reqTimeoutFactor)
   185  		} else if herr := cmn.Err2HTTPErr(err); herr != nil {
   186  			nlog.Warningf("%s [retries: %d/%d]: failed to perform request: %v (code: %d)", task, i, retryCnt, err, herr.Status)
   187  			if _, exists := terminalStatuses[herr.Status]; exists {
   188  				return err // nothing we can do
   189  			}
   190  		} else {
   191  			if !cos.IsRetriableConnErr(err) {
   192  				return err // ditto
   193  			}
   194  			nlog.Warningf("%s [retries: %d/%d]: connection failed with (%v), retrying...", task, i, retryCnt, err)
   195  		}
   196  		task.reset()
   197  	}
   198  	return err
   199  }
   200  
   201  func (task *singleTask) setTotalSize(size int64) {
   202  	if size > 0 {
   203  		task.totalSize.Store(size)
   204  	}
   205  }
   206  
// reset zeroes the progress counters (total and current size);
// downloadLocal calls it before retrying a failed attempt.
func (task *singleTask) reset() {
	task.totalSize.Store(0)
	task.currentSize.Store(0)
}
   211  
   212  func (task *singleTask) downloadRemote(lom *core.LOM) error {
   213  	// Set custom context values (used by `ais/backend/*`).
   214  	ctx, cancel := context.WithTimeout(task.downloadCtx, task.initialTimeout())
   215  	defer cancel()
   216  
   217  	ctx = context.WithValue(ctx, cos.CtxReadWrapper, cos.ReadWrapperFunc(task.wrapReader))
   218  	ctx = context.WithValue(ctx, cos.CtxSetSize, cos.SetSizeFunc(task.setTotalSize))
   219  	task.getCtx = ctx
   220  
   221  	// Do final GET (prefetch) request.
   222  	_, err := core.T.GetCold(ctx, lom, cmn.OwtGetTryLock)
   223  	return err
   224  }
   225  
   226  func (task *singleTask) initialTimeout() time.Duration {
   227  	config := cmn.GCO.Get()
   228  	timeout := config.Downloader.Timeout.D()
   229  	if task.job.Timeout() != 0 {
   230  		timeout = task.job.Timeout()
   231  	}
   232  	return timeout
   233  }
   234  
   235  func (task *singleTask) wrapReader(r io.ReadCloser) io.ReadCloser {
   236  	// Create a custom reader to monitor progress every time we read from response body stream.
   237  	r = &progressReader{
   238  		r: r,
   239  		reporter: func(n int64) {
   240  			task.currentSize.Add(n)
   241  			nl.OnProgress(task.job.Notif())
   242  		},
   243  	}
   244  	// Wrap around throttler reader (noop if throttling is disabled).
   245  	r = task.job.throttler().wrapReader(task.getCtx, r)
   246  	return r
   247  }
   248  
   249  // Probably we need to extend the persistent database (db.go) so that it will contain
   250  // also information about specific tasks.
   251  func (task *singleTask) markFailed(statusMsg string) {
   252  	g.tstats.IncErr(stats.ErrDownloadCount)
   253  	g.store.persistError(task.jobID(), task.obj.objName, statusMsg)
   254  	g.store.incErrorCnt(task.jobID())
   255  }
   256  
   257  func (task *singleTask) persist() {
   258  	if err := g.store.persistTaskInfo(task); err != nil {
   259  		nlog.Errorln(err)
   260  	}
   261  }
   262  
   263  func (task *singleTask) jobID() string { return task.job.ID() }
   264  
   265  func (task *singleTask) uid() string {
   266  	return fmt.Sprintf("%s|%s|%s|%v", task.obj.link, task.job.Bck(), task.obj.objName, task.obj.fromRemote)
   267  }
   268  
   269  func (task *singleTask) ToTaskDlInfo() TaskDlInfo {
   270  	ended := task.ended.Load()
   271  	return TaskDlInfo{
   272  		Name:       task.obj.objName,
   273  		Downloaded: task.currentSize.Load(),
   274  		Total:      task.totalSize.Load(),
   275  		StartTime:  task.started.Load(),
   276  		EndTime:    ended,
   277  	}
   278  }
   279  
   280  func (task *singleTask) String() (str string) {
   281  	return fmt.Sprintf(
   282  		"{id: %q, obj_name: %q, link: %q, from_remote: %v, bucket: %q}",
   283  		task.jobID(), task.obj.objName, task.obj.link, task.obj.fromRemote, task.job.Bck(),
   284  	)
   285  }