// Package aisloader
/*
 * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
 */

package aisloader

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"net/http/httptrace"
	"net/url"
	"os"
	"time"

	"github.com/NVIDIA/aistore/api"
	"github.com/NVIDIA/aistore/api/apc"
	"github.com/NVIDIA/aistore/api/env"
	"github.com/NVIDIA/aistore/cmn"
	"github.com/NVIDIA/aistore/cmn/cos"
	"github.com/NVIDIA/aistore/cmn/debug"
	"github.com/NVIDIA/aistore/cmn/mono"
	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

// longListTime governs list-objects progress reporting: it is passed as
// api.ListArgs.CallAfter and, when listing S3 directly, is the minimum interval
// between two progress printouts.
const longListTime = 10 * time.Second // list-objects progress

var (
	// cargs: HTTP transport arguments; in this file they are used by newTraceCtx
	// (via cmn.NewTransport) to build each traced transport.
	// see related command-line: `transportArgs.Timeout` and UseHTTPS
	cargs = cmn.TransportArgs{
		UseHTTPProxyEnv: true,
	}
	// sargs: TLS arguments; in this file they are applied by newTraceCtx when
	// the proxy URL is HTTPS. SkipVerify defaults to true.
	// NOTE: client X509 certificate and other `cmn.TLSArgs` variables can be provided via (os.Getenv) environment.
	// See also:
	// - docs/aisloader.md, section "Environment variables"
	// - AIS_ENDPOINT and aisEndpoint
	sargs = cmn.TLSArgs{
		SkipVerify: true,
	}
)

type (
	// traceableTransport is an http.RoundTripper that wraps an *http.Transport,
	// keeps track of the current http request, and implements the httptrace
	// hooks (GotConn, WroteHeaders, WroteRequest, GotFirstResponseByte) that
	// timestamp each phase of the request.
	//
	// Phases are attributed by connection count: the first connection obtained
	// is the proxy's, the second one is the target's (after the redirect).
	// There is no synchronization; in this file a new instance is created for
	// every traced request (see newTraceCtx).
	traceableTransport struct {
		transport             *http.Transport // performs the actual round trips
		current               *http.Request   // most recent request passed to RoundTrip (not read in this file)
		tsBegin               time.Time       // request initialized
		tsProxyConn           time.Time       // connected with proxy
		tsRedirect            time.Time       // redirected
		tsTargetConn          time.Time       // connected with target
		tsHTTPEnd             time.Time       // http request returned
		tsProxyWroteHeaders   time.Time
		tsProxyWroteRequest   time.Time
		tsProxyFirstResponse  time.Time
		tsTargetWroteHeaders  time.Time
		tsTargetWroteRequest  time.Time
		tsTargetFirstResponse time.Time
		connCnt               int // number of connections obtained so far (incremented by GotConn)
	}

	// traceCtx bundles a traceableTransport with the httptrace hooks bound to it
	// and an http.Client that sends requests through it.
	traceCtx struct {
		tr           *traceableTransport
		trace        *httptrace.ClientTrace
		tracedClient *http.Client
	}
	// tracePutter holds the state used by its `do` request callback for a traced PUT.
	tracePutter struct {
		tctx   *traceCtx
		cksum  *cos.Cksum           // optional; when non-nil, sent as object checksum headers
		reader cos.ReadOpenCloser // request body; its Open() also backs req.GetBody
	}

	// httpLatencies stores latency of a http request
	// (each phase is 0 if one of its endpoints was never recorded)
	httpLatencies struct {
		ProxyConn           time.Duration // from (request is created) to (proxy connection is established)
		Proxy               time.Duration // from (proxy connection is established) to redirected
		TargetConn          time.Duration // from (request is redirected) to (target connection is established)
		Target              time.Duration // from (target connection is established) to (request is completed)
		PostHTTP            time.Duration // from http ends to after read data from http response and verify hash (if specified)
		ProxyWroteHeader    time.Duration // from ProxyConn to header is written
		ProxyWroteRequest   time.Duration // from ProxyWroteHeader to request (including its body) is written
		ProxyFirstResponse  time.Duration // from ProxyWroteRequest to first byte of response
		TargetWroteHeader   time.Duration // from TargetConn to header is written
		TargetWroteRequest  time.Duration // from TargetWroteHeader to request (including its body) is written
		TargetFirstResponse time.Duration // from TargetWroteRequest to first byte of response
	}
)

////////////////////////
// traceableTransport //
////////////////////////

// RoundTrip implements http.RoundTripper. It records the proxy redirect time
// (the round trip made once the first, proxy, connection has been obtained),
// keeps track of the current request, and delegates to the wrapped transport.
99 func (t *traceableTransport) RoundTrip(req *http.Request) (*http.Response, error) { 100 if t.connCnt == 1 { 101 t.tsRedirect = time.Now() 102 } 103 104 t.current = req 105 return t.transport.RoundTrip(req) 106 } 107 108 // GotConn records when the connection to proxy/target is made. 109 func (t *traceableTransport) GotConn(httptrace.GotConnInfo) { 110 switch t.connCnt { 111 case 0: 112 t.tsProxyConn = time.Now() 113 case 1: 114 t.tsTargetConn = time.Now() 115 default: 116 // ignore 117 // this can happen during proxy stress test when the proxy dies 118 } 119 t.connCnt++ 120 } 121 122 // WroteHeaders records when the header is written to 123 func (t *traceableTransport) WroteHeaders() { 124 switch t.connCnt { 125 case 1: 126 t.tsProxyWroteHeaders = time.Now() 127 case 2: 128 t.tsTargetWroteHeaders = time.Now() 129 default: 130 // ignore 131 } 132 } 133 134 // WroteRequest records when the request is completely written 135 func (t *traceableTransport) WroteRequest(httptrace.WroteRequestInfo) { 136 switch t.connCnt { 137 case 1: 138 t.tsProxyWroteRequest = time.Now() 139 case 2: 140 t.tsTargetWroteRequest = time.Now() 141 default: 142 // ignore 143 } 144 } 145 146 // GotFirstResponseByte records when the response starts to come back 147 func (t *traceableTransport) GotFirstResponseByte() { 148 switch t.connCnt { 149 case 1: 150 t.tsProxyFirstResponse = time.Now() 151 case 2: 152 t.tsTargetFirstResponse = time.Now() 153 default: 154 // ignore 155 } 156 } 157 158 func (t *traceableTransport) set(l *httpLatencies) { 159 l.ProxyConn = timeDelta(t.tsProxyConn, t.tsBegin) 160 l.Proxy = timeDelta(t.tsRedirect, t.tsProxyConn) 161 l.TargetConn = timeDelta(t.tsTargetConn, t.tsRedirect) 162 l.Target = timeDelta(t.tsHTTPEnd, t.tsTargetConn) 163 l.PostHTTP = time.Since(t.tsHTTPEnd) 164 l.ProxyWroteHeader = timeDelta(t.tsProxyWroteHeaders, t.tsProxyConn) 165 l.ProxyWroteRequest = timeDelta(t.tsProxyWroteRequest, t.tsProxyWroteHeaders) 166 l.ProxyFirstResponse = 
timeDelta(t.tsProxyFirstResponse, t.tsProxyWroteRequest) 167 l.TargetWroteHeader = timeDelta(t.tsTargetWroteHeaders, t.tsTargetConn) 168 l.TargetWroteRequest = timeDelta(t.tsTargetWroteRequest, t.tsTargetWroteHeaders) 169 l.TargetFirstResponse = timeDelta(t.tsTargetFirstResponse, t.tsTargetWroteRequest) 170 } 171 172 ////////////////////////////////// 173 // detailed http trace _putter_ // 174 ////////////////////////////////// 175 176 // implements callback of the type `api.NewRequestCB` 177 func (putter *tracePutter) do(reqArgs *cmn.HreqArgs) (*http.Request, error) { 178 req, err := reqArgs.Req() 179 if err != nil { 180 return nil, err 181 } 182 183 // The HTTP package doesn't automatically set this for files, so it has to be done manually 184 // If it wasn't set, we would need to deal with the redirect manually. 185 req.GetBody = func() (io.ReadCloser, error) { 186 return putter.reader.Open() 187 } 188 if putter.cksum != nil { 189 req.Header.Set(apc.HdrObjCksumType, putter.cksum.Ty()) 190 req.Header.Set(apc.HdrObjCksumVal, putter.cksum.Val()) 191 } 192 return req.WithContext(httptrace.WithClientTrace(req.Context(), putter.tctx.trace)), nil 193 } 194 195 // a bare-minimum (e.g. 
not passing checksum or any other metadata) 196 func s3put(bck cmn.Bck, objName string, reader cos.ReadOpenCloser) (err error) { 197 uploader := s3manager.NewUploader(s3svc) 198 _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ 199 Bucket: aws.String(bck.Name), 200 Key: aws.String(objName), 201 Body: reader, 202 }) 203 erc := reader.Close() 204 debug.AssertNoErr(erc) 205 return 206 } 207 208 func put(proxyURL string, bck cmn.Bck, objName string, cksum *cos.Cksum, reader cos.ReadOpenCloser) (err error) { 209 var ( 210 baseParams = api.BaseParams{ 211 Client: runParams.bp.Client, 212 URL: proxyURL, 213 Method: http.MethodPut, 214 Token: loggedUserToken, 215 UA: ua, 216 } 217 args = api.PutArgs{ 218 BaseParams: baseParams, 219 Bck: bck, 220 ObjName: objName, 221 Cksum: cksum, 222 Reader: reader, 223 SkipVC: true, 224 } 225 ) 226 _, err = api.PutObject(&args) 227 return 228 } 229 230 // PUT with HTTP trace 231 func putWithTrace(proxyURL string, bck cmn.Bck, objName string, latencies *httpLatencies, cksum *cos.Cksum, reader cos.ReadOpenCloser) error { 232 reqArgs := cmn.AllocHra() 233 { 234 reqArgs.Method = http.MethodPut 235 reqArgs.Base = proxyURL 236 reqArgs.Path = apc.URLPathObjects.Join(bck.Name, objName) 237 reqArgs.Query = bck.NewQuery() 238 reqArgs.BodyR = reader 239 } 240 putter := tracePutter{ 241 tctx: newTraceCtx(proxyURL), 242 cksum: cksum, 243 reader: reader, 244 } 245 _, err := api.DoWithRetry(putter.tctx.tracedClient, putter.do, reqArgs) //nolint:bodyclose // it's closed inside 246 cmn.FreeHra(reqArgs) 247 if err != nil { 248 return err 249 } 250 tctx := putter.tctx 251 tctx.tr.tsHTTPEnd = time.Now() 252 253 tctx.tr.set(latencies) 254 return nil 255 } 256 257 func newTraceCtx(proxyURL string) *traceCtx { 258 var ( 259 tctx = &traceCtx{} 260 transport = cmn.NewTransport(cargs) 261 err error 262 ) 263 if cos.IsHTTPS(proxyURL) { 264 transport.TLSClientConfig, err = cmn.NewTLS(sargs) 265 cos.AssertNoErr(err) 266 } 267 tctx.tr = 
&traceableTransport{ 268 transport: transport, 269 tsBegin: time.Now(), 270 } 271 tctx.trace = &httptrace.ClientTrace{ 272 GotConn: tctx.tr.GotConn, 273 WroteHeaders: tctx.tr.WroteHeaders, 274 WroteRequest: tctx.tr.WroteRequest, 275 GotFirstResponseByte: tctx.tr.GotFirstResponseByte, 276 } 277 tctx.tracedClient = &http.Client{ 278 Transport: tctx.tr, 279 Timeout: 600 * time.Second, 280 } 281 return tctx 282 } 283 284 func newGetRequest(proxyURL string, bck cmn.Bck, objName string, offset, length int64, latest bool) (*http.Request, error) { 285 var ( 286 hdr http.Header 287 query = url.Values{} 288 ) 289 query = bck.AddToQuery(query) 290 if etlName != "" { 291 query.Add(apc.QparamETLName, etlName) 292 } 293 if latest { 294 query.Add(apc.QparamLatestVer, "true") 295 } 296 if length > 0 { 297 rng := cmn.MakeRangeHdr(offset, length) 298 hdr = http.Header{cos.HdrRange: []string{rng}} 299 } 300 reqArgs := cmn.HreqArgs{ 301 Method: http.MethodGet, 302 Base: proxyURL, 303 Path: apc.URLPathObjects.Join(bck.Name, objName), 304 Query: query, 305 Header: hdr, 306 } 307 return reqArgs.Req() 308 } 309 310 func s3getDiscard(bck cmn.Bck, objName string) (int64, error) { 311 obj, err := s3svc.GetObject(context.Background(), &s3.GetObjectInput{ 312 Bucket: aws.String(bck.Name), 313 Key: aws.String(objName), 314 }) 315 if err != nil { 316 if obj != nil && obj.Body != nil { 317 io.Copy(io.Discard, obj.Body) 318 obj.Body.Close() 319 } 320 return 0, err // detailed enough 321 } 322 323 var size, n int64 324 size = *obj.ContentLength 325 n, err = io.Copy(io.Discard, obj.Body) 326 obj.Body.Close() 327 328 if err != nil { 329 return n, fmt.Errorf("failed to GET %s/%s and discard it (%d, %d): %v", bck, objName, n, size, err) 330 } 331 if n != size { 332 err = fmt.Errorf("failed to GET %s/%s: wrong size (%d, %d)", bck, objName, n, size) 333 } 334 return size, err 335 } 336 337 // getDiscard sends a GET request and discards returned data. 
338 func getDiscard(proxyURL string, bck cmn.Bck, objName string, offset, length int64, validate, latest bool) (int64, error) { 339 req, err := newGetRequest(proxyURL, bck, objName, offset, length, latest) 340 if err != nil { 341 return 0, err 342 } 343 resp, err := runParams.bp.Client.Do(req) 344 if err != nil { 345 return 0, err 346 } 347 348 var hdrCksumValue, hdrCksumType string 349 if validate { 350 hdrCksumValue = resp.Header.Get(apc.HdrObjCksumVal) 351 hdrCksumType = resp.Header.Get(apc.HdrObjCksumType) 352 } 353 src := "GET " + bck.Cname(objName) 354 n, cksumValue, err := readDiscard(resp, src, hdrCksumType) 355 356 resp.Body.Close() 357 if err != nil { 358 return 0, err 359 } 360 if validate && hdrCksumValue != cksumValue { 361 return 0, cmn.NewErrInvalidCksum(hdrCksumValue, cksumValue) 362 } 363 return n, err 364 } 365 366 // Same as above, but with HTTP trace. 367 func getTraceDiscard(proxyURL string, bck cmn.Bck, objName string, latencies *httpLatencies, offset, length int64, validate, latest bool) (int64, error) { 368 var ( 369 hdrCksumValue string 370 hdrCksumType string 371 ) 372 req, err := newGetRequest(proxyURL, bck, objName, offset, length, latest) 373 if err != nil { 374 return 0, err 375 } 376 377 tctx := newTraceCtx(proxyURL) 378 req = req.WithContext(httptrace.WithClientTrace(req.Context(), tctx.trace)) 379 380 resp, err := tctx.tracedClient.Do(req) 381 if err != nil { 382 return 0, err 383 } 384 defer resp.Body.Close() 385 386 tctx.tr.tsHTTPEnd = time.Now() 387 if validate { 388 hdrCksumValue = resp.Header.Get(apc.HdrObjCksumVal) 389 hdrCksumType = resp.Header.Get(apc.HdrObjCksumType) 390 } 391 392 src := "GET " + bck.Cname(objName) 393 n, cksumValue, err := readDiscard(resp, src, hdrCksumType) 394 if err != nil { 395 return 0, err 396 } 397 if validate && hdrCksumValue != cksumValue { 398 err = cmn.NewErrInvalidCksum(hdrCksumValue, cksumValue) 399 } 400 401 tctx.tr.set(latencies) 402 return n, err 403 } 404 405 // getConfig sends a 
{what:config} request to the url and discard the message 406 // For testing purpose only 407 func getConfig(proxyURL string) (httpLatencies, error) { 408 tctx := newTraceCtx(proxyURL) 409 410 url := proxyURL + apc.URLPathDae.S 411 req, _ := http.NewRequest(http.MethodGet, url, http.NoBody) 412 req.URL.RawQuery = api.GetWhatRawQuery(apc.WhatNodeConfig, "") 413 req = req.WithContext(httptrace.WithClientTrace(req.Context(), tctx.trace)) 414 415 resp, err := tctx.tracedClient.Do(req) 416 if err != nil { 417 return httpLatencies{}, err 418 } 419 defer resp.Body.Close() 420 421 _, _, err = readDiscard(resp, "GetConfig", "" /*cksum type*/) 422 423 l := httpLatencies{ 424 ProxyConn: timeDelta(tctx.tr.tsProxyConn, tctx.tr.tsBegin), 425 Proxy: time.Since(tctx.tr.tsProxyConn), 426 } 427 return l, err 428 } 429 430 func listObjCallback(ctx *api.LsoCounter) { 431 if ctx.Count() < 0 { 432 return 433 } 434 fmt.Printf("\rListing %s objects", cos.FormatBigNum(ctx.Count())) 435 if ctx.IsFinished() { 436 fmt.Println() 437 } 438 } 439 440 // listObjectNames returns a slice of object names of all objects that match the prefix in a bucket. 
441 func listObjectNames(baseParams api.BaseParams, bck cmn.Bck, prefix string, cached bool) ([]string, error) { 442 msg := &apc.LsoMsg{Prefix: prefix} 443 // if bck is remote then check for cached flag 444 if cached { 445 msg.Flags |= apc.LsObjCached 446 } 447 args := api.ListArgs{Callback: listObjCallback, CallAfter: longListTime} 448 objList, err := api.ListObjects(baseParams, bck, msg, args) 449 if err != nil { 450 return nil, err 451 } 452 453 objs := make([]string, 0, len(objList.Entries)) 454 for _, obj := range objList.Entries { 455 objs = append(objs, obj.Name) 456 } 457 return objs, nil 458 } 459 460 func initS3Svc() error { 461 // '--s3profile' takes precedence 462 if s3Profile == "" { 463 if profile := os.Getenv(env.AWS.Profile); profile != "" { 464 s3Profile = profile 465 } 466 } 467 cfg, err := config.LoadDefaultConfig( 468 context.Background(), 469 config.WithSharedConfigProfile(s3Profile), 470 ) 471 if err != nil { 472 return err 473 } 474 if s3Endpoint != "" { 475 cfg.BaseEndpoint = aws.String(s3Endpoint) 476 } 477 if cfg.Region == "" { 478 cfg.Region = env.AwsDefaultRegion() 479 } 480 481 s3svc = s3.NewFromConfig(cfg, func(o *s3.Options) { 482 o.UsePathStyle = s3UsePathStyle 483 }) 484 return nil 485 } 486 487 func s3ListObjects() ([]string, error) { 488 // first page 489 params := &s3.ListObjectsV2Input{Bucket: aws.String(runParams.bck.Name)} 490 params.MaxKeys = aws.Int32(apc.MaxPageSizeAWS) 491 492 prev := mono.NanoTime() 493 resp, err := s3svc.ListObjectsV2(context.Background(), params) 494 if err != nil { 495 return nil, err 496 } 497 498 var ( 499 token string 500 l = len(resp.Contents) 501 ) 502 if resp.NextContinuationToken != nil { 503 token = *resp.NextContinuationToken 504 } 505 if token != "" { 506 l = 16 * apc.MaxPageSizeAWS 507 } 508 names := make([]string, 0, l) 509 for _, object := range resp.Contents { 510 names = append(names, *object.Key) 511 } 512 if token == "" { 513 return names, nil 514 } 515 516 // get all the rest pages in 
one fell swoop 517 var eol bool 518 for token != "" { 519 params.ContinuationToken = &token 520 resp, err = s3svc.ListObjectsV2(context.Background(), params) 521 if err != nil { 522 return nil, err 523 } 524 for _, object := range resp.Contents { 525 names = append(names, *object.Key) 526 } 527 token = "" 528 if resp.NextContinuationToken != nil { 529 token = *resp.NextContinuationToken 530 } 531 now := mono.NanoTime() 532 if time.Duration(now-prev) >= longListTime { 533 fmt.Printf("\rListing %s objects", cos.FormatBigNum(len(names))) 534 prev = now 535 eol = true 536 } 537 } 538 if eol { 539 fmt.Println() 540 } 541 return names, nil 542 } 543 544 func readDiscard(r *http.Response, tag, cksumType string) (int64, string, error) { 545 var ( 546 n int64 547 cksum *cos.CksumHash 548 err error 549 cksumValue string 550 ) 551 if r.StatusCode >= http.StatusBadRequest { 552 bytes, err := io.ReadAll(r.Body) 553 if err == nil { 554 return 0, "", fmt.Errorf("bad status %d from %s, response: %s", r.StatusCode, tag, string(bytes)) 555 } 556 return 0, "", fmt.Errorf("bad status %d from %s: %v", r.StatusCode, tag, err) 557 } 558 n, cksum, err = cos.CopyAndChecksum(io.Discard, r.Body, nil, cksumType) 559 if err != nil { 560 return 0, "", fmt.Errorf("failed to read HTTP response, err: %v", err) 561 } 562 if cksum != nil { 563 cksumValue = cksum.Value() 564 } 565 return n, cksumValue, nil 566 } 567 568 func timeDelta(time1, time2 time.Time) time.Duration { 569 if time1.IsZero() || time2.IsZero() { 570 return 0 571 } 572 return time1.Sub(time2) 573 }