github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ext/dload/task.go (about) 1 // Package dload implements functionality to download resources into AIS cluster from external source. 2 /* 3 * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved. 4 */ 5 package dload 6 7 import ( 8 "context" 9 "errors" 10 "fmt" 11 "io" 12 "net/http" 13 "os" 14 "time" 15 16 "github.com/NVIDIA/aistore/cmn" 17 "github.com/NVIDIA/aistore/cmn/atomic" 18 "github.com/NVIDIA/aistore/cmn/cos" 19 "github.com/NVIDIA/aistore/cmn/nlog" 20 "github.com/NVIDIA/aistore/core" 21 "github.com/NVIDIA/aistore/nl" 22 "github.com/NVIDIA/aistore/stats" 23 ) 24 25 const ( 26 gcsUA = "gcloud-golang-storage/20151204" // from cloud.google.com/go/storage/storage.go (userAgent). 27 ) 28 29 const ( 30 retryCnt = 10 // number of retries to external resource 31 reqTimeoutFactor = 1.2 // newTimeout = prevTimeout * reqTimeoutFactor 32 internalErrorMsg = "internal server error" 33 ) 34 35 type singleTask struct { 36 xdl *Xact 37 job jobif 38 obj dlObj 39 started atomic.Time 40 ended atomic.Time 41 currentSize atomic.Int64 // current file size (updated as the download progresses) 42 totalSize atomic.Int64 // total size (nonzero iff Content-Length header was provided by the source) 43 downloadCtx context.Context // w/ cancel function 44 getCtx context.Context // w/ timeout and size 45 cancel context.CancelFunc // to cancel in-progress download 46 } 47 48 // List of HTTP status codes which we shouldn'task retry (just report the job failed). 49 var terminalStatuses = map[int]struct{}{ 50 http.StatusNotFound: {}, 51 http.StatusPaymentRequired: {}, 52 http.StatusUnauthorized: {}, 53 http.StatusForbidden: {}, 54 http.StatusMethodNotAllowed: {}, 55 http.StatusNotAcceptable: {}, 56 http.StatusProxyAuthRequired: {}, 57 http.StatusGone: {}, 58 } 59 60 //////////////// 61 // singleTask // 62 //////////////// 63 64 func (task *singleTask) init() { 65 // NOTE: `cancel` is called on abort or when download finishes. 66 task.downloadCtx, task.cancel = context.WithCancel(context.Background()) 67 } 68 69 func (task *singleTask) download(lom *core.LOM) { 70 err := lom.InitBck(task.job.Bck()) 71 if err == nil { 72 err = lom.Load(true /*cache it*/, false /*locked*/) 73 } 74 if err != nil && !os.IsNotExist(err) { 75 task.markFailed(internalErrorMsg) 76 return 77 } 78 79 if cmn.Rom.FastV(4, cos.SmoduleDload) { 80 nlog.Infof("Starting download for %v", task) 81 } 82 83 task.started.Store(time.Now()) 84 lom.SetAtimeUnix(task.started.Load().UnixNano()) 85 if task.obj.fromRemote { 86 err = task.downloadRemote(lom) 87 } else { 88 err = task.downloadLocal(lom) 89 } 90 task.ended.Store(time.Now()) 91 92 if err != nil { 93 task.markFailed(err.Error()) 94 return 95 } 96 97 g.store.incFinished(task.jobID()) 98 99 g.tstats.AddMany( 100 cos.NamedVal64{Name: stats.DownloadSize, Value: task.currentSize.Load()}, 101 cos.NamedVal64{Name: stats.DownloadLatency, Value: int64(task.ended.Load().Sub(task.started.Load()))}, 102 ) 103 task.xdl.ObjsAdd(1, task.currentSize.Load()) 104 } 105 106 func (task *singleTask) _dlocal(lom *core.LOM, timeout time.Duration) (bool /*err is fatal*/, error) { 107 ctx, cancel := context.WithTimeout(task.downloadCtx, timeout) 108 defer cancel() 109 110 task.getCtx = ctx 111 112 req, err := http.NewRequestWithContext(ctx, http.MethodGet, task.obj.link, http.NoBody) 113 if err != nil { 114 return true, err 115 } 116 117 // Set "User-Agent" header when doing requests to Google Cloud Storage. 118 // This should increase the number of connections to GCS. 119 if cos.IsGoogleStorageURL(req.URL) { 120 req.Header.Add("User-Agent", gcsUA) 121 } 122 123 resp, err := clientForURL(task.obj.link).Do(req) //nolint:bodyclose // cos.Close 124 if err != nil { 125 return false, err 126 } 127 128 fatal, err := task._dput(lom, req, resp) 129 cos.Close(resp.Body) 130 return fatal, err 131 } 132 133 func (task *singleTask) _dput(lom *core.LOM, req *http.Request, resp *http.Response) (bool /*err is fatal*/, error) { 134 if resp.StatusCode >= http.StatusBadRequest { 135 if resp.StatusCode == http.StatusNotFound { 136 return false, cmn.NewErrHTTP(req, fmt.Errorf("%q does not exist", task.obj.link), http.StatusNotFound) 137 } 138 return false, cmn.NewErrHTTP(req, 139 fmt.Errorf("failed to download %q: status %d", task.obj.link, resp.StatusCode), 140 resp.StatusCode) 141 } 142 143 r := task.wrapReader(resp.Body) 144 size := attrsFromLink(task.obj.link, resp, lom) 145 task.setTotalSize(size) 146 147 params := core.AllocPutParams() 148 { 149 params.WorkTag = "dl" 150 params.Reader = r 151 params.OWT = cmn.OwtPut 152 params.Atime = task.started.Load() 153 params.Size = size 154 params.Xact = task.xdl 155 } 156 erp := core.T.PutObject(lom, params) 157 core.FreePutParams(params) 158 if erp != nil { 159 return true, erp 160 } 161 if err := lom.Load(true /*cache it*/, false /*locked*/); err != nil { 162 return true, err 163 } 164 return false, nil 165 } 166 167 func (task *singleTask) downloadLocal(lom *core.LOM) (err error) { 168 var ( 169 timeout = task.initialTimeout() 170 fatal bool 171 ) 172 for i := range retryCnt { 173 fatal, err = task._dlocal(lom, timeout) 174 if err == nil || fatal { 175 return err 176 } 177 178 // handle more 179 if errors.Is(err, context.Canceled) || errors.Is(err, errThrottlerStopped) { 180 return err // canceled or stopped, so just return 181 } 182 if errors.Is(err, context.DeadlineExceeded) { 183 nlog.Warningf("%s [retries: %d/%d]: timeout (%v) - increasing and retrying", task, i, retryCnt, timeout) 184 timeout = time.Duration(float64(timeout) * reqTimeoutFactor) 185 } else if herr := cmn.Err2HTTPErr(err); herr != nil { 186 nlog.Warningf("%s [retries: %d/%d]: failed to perform request: %v (code: %d)", task, i, retryCnt, err, herr.Status) 187 if _, exists := terminalStatuses[herr.Status]; exists { 188 return err // nothing we can do 189 } 190 } else { 191 if !cos.IsRetriableConnErr(err) { 192 return err // ditto 193 } 194 nlog.Warningf("%s [retries: %d/%d]: connection failed with (%v), retrying...", task, i, retryCnt, err) 195 } 196 task.reset() 197 } 198 return err 199 } 200 201 func (task *singleTask) setTotalSize(size int64) { 202 if size > 0 { 203 task.totalSize.Store(size) 204 } 205 } 206 207 func (task *singleTask) reset() { 208 task.totalSize.Store(0) 209 task.currentSize.Store(0) 210 } 211 212 func (task *singleTask) downloadRemote(lom *core.LOM) error { 213 // Set custom context values (used by `ais/backend/*`). 214 ctx, cancel := context.WithTimeout(task.downloadCtx, task.initialTimeout()) 215 defer cancel() 216 217 ctx = context.WithValue(ctx, cos.CtxReadWrapper, cos.ReadWrapperFunc(task.wrapReader)) 218 ctx = context.WithValue(ctx, cos.CtxSetSize, cos.SetSizeFunc(task.setTotalSize)) 219 task.getCtx = ctx 220 221 // Do final GET (prefetch) request. 222 _, err := core.T.GetCold(ctx, lom, cmn.OwtGetTryLock) 223 return err 224 } 225 226 func (task *singleTask) initialTimeout() time.Duration { 227 config := cmn.GCO.Get() 228 timeout := config.Downloader.Timeout.D() 229 if task.job.Timeout() != 0 { 230 timeout = task.job.Timeout() 231 } 232 return timeout 233 } 234 235 func (task *singleTask) wrapReader(r io.ReadCloser) io.ReadCloser { 236 // Create a custom reader to monitor progress every time we read from response body stream. 237 r = &progressReader{ 238 r: r, 239 reporter: func(n int64) { 240 task.currentSize.Add(n) 241 nl.OnProgress(task.job.Notif()) 242 }, 243 } 244 // Wrap around throttler reader (noop if throttling is disabled). 245 r = task.job.throttler().wrapReader(task.getCtx, r) 246 return r 247 } 248 249 // Probably we need to extend the persistent database (db.go) so that it will contain 250 // also information about specific tasks. 251 func (task *singleTask) markFailed(statusMsg string) { 252 g.tstats.IncErr(stats.ErrDownloadCount) 253 g.store.persistError(task.jobID(), task.obj.objName, statusMsg) 254 g.store.incErrorCnt(task.jobID()) 255 } 256 257 func (task *singleTask) persist() { 258 if err := g.store.persistTaskInfo(task); err != nil { 259 nlog.Errorln(err) 260 } 261 } 262 263 func (task *singleTask) jobID() string { return task.job.ID() } 264 265 func (task *singleTask) uid() string { 266 return fmt.Sprintf("%s|%s|%s|%v", task.obj.link, task.job.Bck(), task.obj.objName, task.obj.fromRemote) 267 } 268 269 func (task *singleTask) ToTaskDlInfo() TaskDlInfo { 270 ended := task.ended.Load() 271 return TaskDlInfo{ 272 Name: task.obj.objName, 273 Downloaded: task.currentSize.Load(), 274 Total: task.totalSize.Load(), 275 StartTime: task.started.Load(), 276 EndTime: ended, 277 } 278 } 279 280 func (task *singleTask) String() (str string) { 281 return fmt.Sprintf( 282 "{id: %q, obj_name: %q, link: %q, from_remote: %v, bucket: %q}", 283 task.jobID(), task.obj.objName, task.obj.link, task.obj.fromRemote, task.job.Bck(), 284 ) 285 }