github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/bench/tools/aisloader/client.go (about)

     1  // Package aisloader
     2  /*
     3   * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved.
     4   */
     5  
     6  package aisloader
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"io"
    12  	"net/http"
    13  	"net/http/httptrace"
    14  	"net/url"
    15  	"os"
    16  	"time"
    17  
    18  	"github.com/NVIDIA/aistore/api"
    19  	"github.com/NVIDIA/aistore/api/apc"
    20  	"github.com/NVIDIA/aistore/api/env"
    21  	"github.com/NVIDIA/aistore/cmn"
    22  	"github.com/NVIDIA/aistore/cmn/cos"
    23  	"github.com/NVIDIA/aistore/cmn/debug"
    24  	"github.com/NVIDIA/aistore/cmn/mono"
    25  	"github.com/aws/aws-sdk-go-v2/aws"
    26  	"github.com/aws/aws-sdk-go-v2/config"
    27  	s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
    28  	"github.com/aws/aws-sdk-go-v2/service/s3"
    29  )
    30  
    // longListTime is the time threshold used when reporting progress of
// long-running list-objects operations.
const longListTime time.Duration = 10 * time.Second
    32  
    33  var (
    34  	// see related command-line: `transportArgs.Timeout` and UseHTTPS
    35  	cargs = cmn.TransportArgs{
    36  		UseHTTPProxyEnv: true,
    37  	}
    38  	// NOTE: client X509 certificate and other `cmn.TLSArgs` variables can be provided via (os.Getenv) environment.
    39  	// See also:
    40  	// - docs/aisloader.md, section "Environment variables"
    41  	// - AIS_ENDPOINT and aisEndpoint
    42  	sargs = cmn.TLSArgs{
    43  		SkipVerify: true,
    44  	}
    45  )
    46  
    47  type (
    48  	// traceableTransport is an http.RoundTripper that keeps track of a http
    49  	// request and implements hooks to report HTTP tracing events.
    50  	traceableTransport struct {
    51  		transport             *http.Transport
    52  		current               *http.Request
    53  		tsBegin               time.Time // request initialized
    54  		tsProxyConn           time.Time // connected with proxy
    55  		tsRedirect            time.Time // redirected
    56  		tsTargetConn          time.Time // connected with target
    57  		tsHTTPEnd             time.Time // http request returned
    58  		tsProxyWroteHeaders   time.Time
    59  		tsProxyWroteRequest   time.Time
    60  		tsProxyFirstResponse  time.Time
    61  		tsTargetWroteHeaders  time.Time
    62  		tsTargetWroteRequest  time.Time
    63  		tsTargetFirstResponse time.Time
    64  		connCnt               int
    65  	}
    66  
    67  	traceCtx struct {
    68  		tr           *traceableTransport
    69  		trace        *httptrace.ClientTrace
    70  		tracedClient *http.Client
    71  	}
    72  	tracePutter struct {
    73  		tctx   *traceCtx
    74  		cksum  *cos.Cksum
    75  		reader cos.ReadOpenCloser
    76  	}
    77  
    78  	// httpLatencies stores latency of a http request
    79  	httpLatencies struct {
    80  		ProxyConn           time.Duration // from (request is created) to (proxy connection is established)
    81  		Proxy               time.Duration // from (proxy connection is established) to redirected
    82  		TargetConn          time.Duration // from (request is redirected) to (target connection is established)
    83  		Target              time.Duration // from (target connection is established) to (request is completed)
    84  		PostHTTP            time.Duration // from http ends to after read data from http response and verify hash (if specified)
    85  		ProxyWroteHeader    time.Duration // from ProxyConn to header is written
    86  		ProxyWroteRequest   time.Duration // from ProxyWroteHeader to response body is written
    87  		ProxyFirstResponse  time.Duration // from ProxyWroteRequest to first byte of response
    88  		TargetWroteHeader   time.Duration // from TargetConn to header is written
    89  		TargetWroteRequest  time.Duration // from TargetWroteHeader to response body is written
    90  		TargetFirstResponse time.Duration // from TargetWroteRequest to first byte of response
    91  	}
    92  )
    93  
    94  ////////////////////////
    95  // traceableTransport //
    96  ////////////////////////
    97  
    98  // RoundTrip records the proxy redirect time and keeps track of requests.
    99  func (t *traceableTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   100  	if t.connCnt == 1 {
   101  		t.tsRedirect = time.Now()
   102  	}
   103  
   104  	t.current = req
   105  	return t.transport.RoundTrip(req)
   106  }
   107  
   108  // GotConn records when the connection to proxy/target is made.
   109  func (t *traceableTransport) GotConn(httptrace.GotConnInfo) {
   110  	switch t.connCnt {
   111  	case 0:
   112  		t.tsProxyConn = time.Now()
   113  	case 1:
   114  		t.tsTargetConn = time.Now()
   115  	default:
   116  		// ignore
   117  		// this can happen during proxy stress test when the proxy dies
   118  	}
   119  	t.connCnt++
   120  }
   121  
   122  // WroteHeaders records when the header is written to
   123  func (t *traceableTransport) WroteHeaders() {
   124  	switch t.connCnt {
   125  	case 1:
   126  		t.tsProxyWroteHeaders = time.Now()
   127  	case 2:
   128  		t.tsTargetWroteHeaders = time.Now()
   129  	default:
   130  		// ignore
   131  	}
   132  }
   133  
   134  // WroteRequest records when the request is completely written
   135  func (t *traceableTransport) WroteRequest(httptrace.WroteRequestInfo) {
   136  	switch t.connCnt {
   137  	case 1:
   138  		t.tsProxyWroteRequest = time.Now()
   139  	case 2:
   140  		t.tsTargetWroteRequest = time.Now()
   141  	default:
   142  		// ignore
   143  	}
   144  }
   145  
   146  // GotFirstResponseByte records when the response starts to come back
   147  func (t *traceableTransport) GotFirstResponseByte() {
   148  	switch t.connCnt {
   149  	case 1:
   150  		t.tsProxyFirstResponse = time.Now()
   151  	case 2:
   152  		t.tsTargetFirstResponse = time.Now()
   153  	default:
   154  		// ignore
   155  	}
   156  }
   157  
   158  func (t *traceableTransport) set(l *httpLatencies) {
   159  	l.ProxyConn = timeDelta(t.tsProxyConn, t.tsBegin)
   160  	l.Proxy = timeDelta(t.tsRedirect, t.tsProxyConn)
   161  	l.TargetConn = timeDelta(t.tsTargetConn, t.tsRedirect)
   162  	l.Target = timeDelta(t.tsHTTPEnd, t.tsTargetConn)
   163  	l.PostHTTP = time.Since(t.tsHTTPEnd)
   164  	l.ProxyWroteHeader = timeDelta(t.tsProxyWroteHeaders, t.tsProxyConn)
   165  	l.ProxyWroteRequest = timeDelta(t.tsProxyWroteRequest, t.tsProxyWroteHeaders)
   166  	l.ProxyFirstResponse = timeDelta(t.tsProxyFirstResponse, t.tsProxyWroteRequest)
   167  	l.TargetWroteHeader = timeDelta(t.tsTargetWroteHeaders, t.tsTargetConn)
   168  	l.TargetWroteRequest = timeDelta(t.tsTargetWroteRequest, t.tsTargetWroteHeaders)
   169  	l.TargetFirstResponse = timeDelta(t.tsTargetFirstResponse, t.tsTargetWroteRequest)
   170  }
   171  
   172  //////////////////////////////////
   173  // detailed http trace _putter_ //
   174  //////////////////////////////////
   175  
   176  // implements callback of the type `api.NewRequestCB`
   177  func (putter *tracePutter) do(reqArgs *cmn.HreqArgs) (*http.Request, error) {
   178  	req, err := reqArgs.Req()
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  
   183  	// The HTTP package doesn't automatically set this for files, so it has to be done manually
   184  	// If it wasn't set, we would need to deal with the redirect manually.
   185  	req.GetBody = func() (io.ReadCloser, error) {
   186  		return putter.reader.Open()
   187  	}
   188  	if putter.cksum != nil {
   189  		req.Header.Set(apc.HdrObjCksumType, putter.cksum.Ty())
   190  		req.Header.Set(apc.HdrObjCksumVal, putter.cksum.Val())
   191  	}
   192  	return req.WithContext(httptrace.WithClientTrace(req.Context(), putter.tctx.trace)), nil
   193  }
   194  
   195  // a bare-minimum (e.g. not passing checksum or any other metadata)
   196  func s3put(bck cmn.Bck, objName string, reader cos.ReadOpenCloser) (err error) {
   197  	uploader := s3manager.NewUploader(s3svc)
   198  	_, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
   199  		Bucket: aws.String(bck.Name),
   200  		Key:    aws.String(objName),
   201  		Body:   reader,
   202  	})
   203  	erc := reader.Close()
   204  	debug.AssertNoErr(erc)
   205  	return
   206  }
   207  
   208  func put(proxyURL string, bck cmn.Bck, objName string, cksum *cos.Cksum, reader cos.ReadOpenCloser) (err error) {
   209  	var (
   210  		baseParams = api.BaseParams{
   211  			Client: runParams.bp.Client,
   212  			URL:    proxyURL,
   213  			Method: http.MethodPut,
   214  			Token:  loggedUserToken,
   215  			UA:     ua,
   216  		}
   217  		args = api.PutArgs{
   218  			BaseParams: baseParams,
   219  			Bck:        bck,
   220  			ObjName:    objName,
   221  			Cksum:      cksum,
   222  			Reader:     reader,
   223  			SkipVC:     true,
   224  		}
   225  	)
   226  	_, err = api.PutObject(&args)
   227  	return
   228  }
   229  
   230  // PUT with HTTP trace
   231  func putWithTrace(proxyURL string, bck cmn.Bck, objName string, latencies *httpLatencies, cksum *cos.Cksum, reader cos.ReadOpenCloser) error {
   232  	reqArgs := cmn.AllocHra()
   233  	{
   234  		reqArgs.Method = http.MethodPut
   235  		reqArgs.Base = proxyURL
   236  		reqArgs.Path = apc.URLPathObjects.Join(bck.Name, objName)
   237  		reqArgs.Query = bck.NewQuery()
   238  		reqArgs.BodyR = reader
   239  	}
   240  	putter := tracePutter{
   241  		tctx:   newTraceCtx(proxyURL),
   242  		cksum:  cksum,
   243  		reader: reader,
   244  	}
   245  	_, err := api.DoWithRetry(putter.tctx.tracedClient, putter.do, reqArgs) //nolint:bodyclose // it's closed inside
   246  	cmn.FreeHra(reqArgs)
   247  	if err != nil {
   248  		return err
   249  	}
   250  	tctx := putter.tctx
   251  	tctx.tr.tsHTTPEnd = time.Now()
   252  
   253  	tctx.tr.set(latencies)
   254  	return nil
   255  }
   256  
   257  func newTraceCtx(proxyURL string) *traceCtx {
   258  	var (
   259  		tctx      = &traceCtx{}
   260  		transport = cmn.NewTransport(cargs)
   261  		err       error
   262  	)
   263  	if cos.IsHTTPS(proxyURL) {
   264  		transport.TLSClientConfig, err = cmn.NewTLS(sargs)
   265  		cos.AssertNoErr(err)
   266  	}
   267  	tctx.tr = &traceableTransport{
   268  		transport: transport,
   269  		tsBegin:   time.Now(),
   270  	}
   271  	tctx.trace = &httptrace.ClientTrace{
   272  		GotConn:              tctx.tr.GotConn,
   273  		WroteHeaders:         tctx.tr.WroteHeaders,
   274  		WroteRequest:         tctx.tr.WroteRequest,
   275  		GotFirstResponseByte: tctx.tr.GotFirstResponseByte,
   276  	}
   277  	tctx.tracedClient = &http.Client{
   278  		Transport: tctx.tr,
   279  		Timeout:   600 * time.Second,
   280  	}
   281  	return tctx
   282  }
   283  
   284  func newGetRequest(proxyURL string, bck cmn.Bck, objName string, offset, length int64, latest bool) (*http.Request, error) {
   285  	var (
   286  		hdr   http.Header
   287  		query = url.Values{}
   288  	)
   289  	query = bck.AddToQuery(query)
   290  	if etlName != "" {
   291  		query.Add(apc.QparamETLName, etlName)
   292  	}
   293  	if latest {
   294  		query.Add(apc.QparamLatestVer, "true")
   295  	}
   296  	if length > 0 {
   297  		rng := cmn.MakeRangeHdr(offset, length)
   298  		hdr = http.Header{cos.HdrRange: []string{rng}}
   299  	}
   300  	reqArgs := cmn.HreqArgs{
   301  		Method: http.MethodGet,
   302  		Base:   proxyURL,
   303  		Path:   apc.URLPathObjects.Join(bck.Name, objName),
   304  		Query:  query,
   305  		Header: hdr,
   306  	}
   307  	return reqArgs.Req()
   308  }
   309  
   310  func s3getDiscard(bck cmn.Bck, objName string) (int64, error) {
   311  	obj, err := s3svc.GetObject(context.Background(), &s3.GetObjectInput{
   312  		Bucket: aws.String(bck.Name),
   313  		Key:    aws.String(objName),
   314  	})
   315  	if err != nil {
   316  		if obj != nil && obj.Body != nil {
   317  			io.Copy(io.Discard, obj.Body)
   318  			obj.Body.Close()
   319  		}
   320  		return 0, err // detailed enough
   321  	}
   322  
   323  	var size, n int64
   324  	size = *obj.ContentLength
   325  	n, err = io.Copy(io.Discard, obj.Body)
   326  	obj.Body.Close()
   327  
   328  	if err != nil {
   329  		return n, fmt.Errorf("failed to GET %s/%s and discard it (%d, %d): %v", bck, objName, n, size, err)
   330  	}
   331  	if n != size {
   332  		err = fmt.Errorf("failed to GET %s/%s: wrong size (%d, %d)", bck, objName, n, size)
   333  	}
   334  	return size, err
   335  }
   336  
   337  // getDiscard sends a GET request and discards returned data.
   338  func getDiscard(proxyURL string, bck cmn.Bck, objName string, offset, length int64, validate, latest bool) (int64, error) {
   339  	req, err := newGetRequest(proxyURL, bck, objName, offset, length, latest)
   340  	if err != nil {
   341  		return 0, err
   342  	}
   343  	resp, err := runParams.bp.Client.Do(req)
   344  	if err != nil {
   345  		return 0, err
   346  	}
   347  
   348  	var hdrCksumValue, hdrCksumType string
   349  	if validate {
   350  		hdrCksumValue = resp.Header.Get(apc.HdrObjCksumVal)
   351  		hdrCksumType = resp.Header.Get(apc.HdrObjCksumType)
   352  	}
   353  	src := "GET " + bck.Cname(objName)
   354  	n, cksumValue, err := readDiscard(resp, src, hdrCksumType)
   355  
   356  	resp.Body.Close()
   357  	if err != nil {
   358  		return 0, err
   359  	}
   360  	if validate && hdrCksumValue != cksumValue {
   361  		return 0, cmn.NewErrInvalidCksum(hdrCksumValue, cksumValue)
   362  	}
   363  	return n, err
   364  }
   365  
   366  // Same as above, but with HTTP trace.
   367  func getTraceDiscard(proxyURL string, bck cmn.Bck, objName string, latencies *httpLatencies, offset, length int64, validate, latest bool) (int64, error) {
   368  	var (
   369  		hdrCksumValue string
   370  		hdrCksumType  string
   371  	)
   372  	req, err := newGetRequest(proxyURL, bck, objName, offset, length, latest)
   373  	if err != nil {
   374  		return 0, err
   375  	}
   376  
   377  	tctx := newTraceCtx(proxyURL)
   378  	req = req.WithContext(httptrace.WithClientTrace(req.Context(), tctx.trace))
   379  
   380  	resp, err := tctx.tracedClient.Do(req)
   381  	if err != nil {
   382  		return 0, err
   383  	}
   384  	defer resp.Body.Close()
   385  
   386  	tctx.tr.tsHTTPEnd = time.Now()
   387  	if validate {
   388  		hdrCksumValue = resp.Header.Get(apc.HdrObjCksumVal)
   389  		hdrCksumType = resp.Header.Get(apc.HdrObjCksumType)
   390  	}
   391  
   392  	src := "GET " + bck.Cname(objName)
   393  	n, cksumValue, err := readDiscard(resp, src, hdrCksumType)
   394  	if err != nil {
   395  		return 0, err
   396  	}
   397  	if validate && hdrCksumValue != cksumValue {
   398  		err = cmn.NewErrInvalidCksum(hdrCksumValue, cksumValue)
   399  	}
   400  
   401  	tctx.tr.set(latencies)
   402  	return n, err
   403  }
   404  
   405  // getConfig sends a {what:config} request to the url and discard the message
   406  // For testing purpose only
   407  func getConfig(proxyURL string) (httpLatencies, error) {
   408  	tctx := newTraceCtx(proxyURL)
   409  
   410  	url := proxyURL + apc.URLPathDae.S
   411  	req, _ := http.NewRequest(http.MethodGet, url, http.NoBody)
   412  	req.URL.RawQuery = api.GetWhatRawQuery(apc.WhatNodeConfig, "")
   413  	req = req.WithContext(httptrace.WithClientTrace(req.Context(), tctx.trace))
   414  
   415  	resp, err := tctx.tracedClient.Do(req)
   416  	if err != nil {
   417  		return httpLatencies{}, err
   418  	}
   419  	defer resp.Body.Close()
   420  
   421  	_, _, err = readDiscard(resp, "GetConfig", "" /*cksum type*/)
   422  
   423  	l := httpLatencies{
   424  		ProxyConn: timeDelta(tctx.tr.tsProxyConn, tctx.tr.tsBegin),
   425  		Proxy:     time.Since(tctx.tr.tsProxyConn),
   426  	}
   427  	return l, err
   428  }
   429  
   430  func listObjCallback(ctx *api.LsoCounter) {
   431  	if ctx.Count() < 0 {
   432  		return
   433  	}
   434  	fmt.Printf("\rListing %s objects", cos.FormatBigNum(ctx.Count()))
   435  	if ctx.IsFinished() {
   436  		fmt.Println()
   437  	}
   438  }
   439  
   440  // listObjectNames returns a slice of object names of all objects that match the prefix in a bucket.
   441  func listObjectNames(baseParams api.BaseParams, bck cmn.Bck, prefix string, cached bool) ([]string, error) {
   442  	msg := &apc.LsoMsg{Prefix: prefix}
   443  	// if bck is remote then check for cached flag
   444  	if cached {
   445  		msg.Flags |= apc.LsObjCached
   446  	}
   447  	args := api.ListArgs{Callback: listObjCallback, CallAfter: longListTime}
   448  	objList, err := api.ListObjects(baseParams, bck, msg, args)
   449  	if err != nil {
   450  		return nil, err
   451  	}
   452  
   453  	objs := make([]string, 0, len(objList.Entries))
   454  	for _, obj := range objList.Entries {
   455  		objs = append(objs, obj.Name)
   456  	}
   457  	return objs, nil
   458  }
   459  
   460  func initS3Svc() error {
   461  	// '--s3profile' takes precedence
   462  	if s3Profile == "" {
   463  		if profile := os.Getenv(env.AWS.Profile); profile != "" {
   464  			s3Profile = profile
   465  		}
   466  	}
   467  	cfg, err := config.LoadDefaultConfig(
   468  		context.Background(),
   469  		config.WithSharedConfigProfile(s3Profile),
   470  	)
   471  	if err != nil {
   472  		return err
   473  	}
   474  	if s3Endpoint != "" {
   475  		cfg.BaseEndpoint = aws.String(s3Endpoint)
   476  	}
   477  	if cfg.Region == "" {
   478  		cfg.Region = env.AwsDefaultRegion()
   479  	}
   480  
   481  	s3svc = s3.NewFromConfig(cfg, func(o *s3.Options) {
   482  		o.UsePathStyle = s3UsePathStyle
   483  	})
   484  	return nil
   485  }
   486  
   487  func s3ListObjects() ([]string, error) {
   488  	// first page
   489  	params := &s3.ListObjectsV2Input{Bucket: aws.String(runParams.bck.Name)}
   490  	params.MaxKeys = aws.Int32(apc.MaxPageSizeAWS)
   491  
   492  	prev := mono.NanoTime()
   493  	resp, err := s3svc.ListObjectsV2(context.Background(), params)
   494  	if err != nil {
   495  		return nil, err
   496  	}
   497  
   498  	var (
   499  		token string
   500  		l     = len(resp.Contents)
   501  	)
   502  	if resp.NextContinuationToken != nil {
   503  		token = *resp.NextContinuationToken
   504  	}
   505  	if token != "" {
   506  		l = 16 * apc.MaxPageSizeAWS
   507  	}
   508  	names := make([]string, 0, l)
   509  	for _, object := range resp.Contents {
   510  		names = append(names, *object.Key)
   511  	}
   512  	if token == "" {
   513  		return names, nil
   514  	}
   515  
   516  	// get all the rest pages in one fell swoop
   517  	var eol bool
   518  	for token != "" {
   519  		params.ContinuationToken = &token
   520  		resp, err = s3svc.ListObjectsV2(context.Background(), params)
   521  		if err != nil {
   522  			return nil, err
   523  		}
   524  		for _, object := range resp.Contents {
   525  			names = append(names, *object.Key)
   526  		}
   527  		token = ""
   528  		if resp.NextContinuationToken != nil {
   529  			token = *resp.NextContinuationToken
   530  		}
   531  		now := mono.NanoTime()
   532  		if time.Duration(now-prev) >= longListTime {
   533  			fmt.Printf("\rListing %s objects", cos.FormatBigNum(len(names)))
   534  			prev = now
   535  			eol = true
   536  		}
   537  	}
   538  	if eol {
   539  		fmt.Println()
   540  	}
   541  	return names, nil
   542  }
   543  
   544  func readDiscard(r *http.Response, tag, cksumType string) (int64, string, error) {
   545  	var (
   546  		n          int64
   547  		cksum      *cos.CksumHash
   548  		err        error
   549  		cksumValue string
   550  	)
   551  	if r.StatusCode >= http.StatusBadRequest {
   552  		bytes, err := io.ReadAll(r.Body)
   553  		if err == nil {
   554  			return 0, "", fmt.Errorf("bad status %d from %s, response: %s", r.StatusCode, tag, string(bytes))
   555  		}
   556  		return 0, "", fmt.Errorf("bad status %d from %s: %v", r.StatusCode, tag, err)
   557  	}
   558  	n, cksum, err = cos.CopyAndChecksum(io.Discard, r.Body, nil, cksumType)
   559  	if err != nil {
   560  		return 0, "", fmt.Errorf("failed to read HTTP response, err: %v", err)
   561  	}
   562  	if cksum != nil {
   563  		cksumValue = cksum.Value()
   564  	}
   565  	return n, cksumValue, nil
   566  }
   567  
   // timeDelta returns time1 - time2, or zero when either timestamp is the zero
// time.Time (i.e., was never recorded).
func timeDelta(time1, time2 time.Time) time.Duration {
	if !time1.IsZero() && !time2.IsZero() {
		return time1.Sub(time2)
	}
	return 0
}