github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ais/s3/presigned.go (about)

// Package s3 provides an Amazon S3 compatibility layer.
     2  /*
     3   * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
     4   */
     5  package s3
     6  
     7  import (
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"github.com/NVIDIA/aistore/api/apc"
    16  	"github.com/NVIDIA/aistore/cmn"
    17  	"github.com/NVIDIA/aistore/cmn/cos"
    18  	"github.com/NVIDIA/aistore/cmn/debug"
    19  	"github.com/NVIDIA/aistore/core"
    20  )
    21  
    22  const (
    23  	signatureV4 = "AWS4-HMAC-SHA256"
    24  )
    25  
    26  type (
    27  	PresignedReq struct {
    28  		oreq  *http.Request
    29  		lom   *core.LOM
    30  		body  io.ReadCloser
    31  		query url.Values
    32  	}
    33  	PresignedResp struct {
    34  		Body       []byte
    35  		BodyR      io.ReadCloser // Set when invoked `Do` with `async` option.
    36  		Size       int64
    37  		Header     http.Header
    38  		StatusCode int
    39  	}
    40  )
    41  
    42  //////////////////
    43  // PresignedReq //
    44  //////////////////
    45  
    46  func NewPresignedReq(oreq *http.Request, lom *core.LOM, body io.ReadCloser, q url.Values) *PresignedReq {
    47  	return &PresignedReq{oreq, lom, body, q}
    48  }
    49  
    50  // FIXME: handle error cases
    51  func parseSignatureV4(query url.Values, header http.Header) (region string) {
    52  	if credentials := query.Get(HeaderCredentials); credentials != "" {
    53  		region = strings.Split(credentials, "/")[2]
    54  	} else if credentials := header.Get(apc.HdrAuthorization); strings.HasPrefix(credentials, signatureV4) {
    55  		credentials = strings.TrimPrefix(credentials, signatureV4)
    56  		credentials = strings.TrimSpace(credentials)
    57  		credentials = strings.Split(credentials, ", ")[0]
    58  		credentials = strings.TrimPrefix(credentials, "Credential=")
    59  		region = strings.Split(credentials, "/")[2]
    60  	}
    61  	return region
    62  }
    63  
    64  func (pts *PresignedReq) Do(client *http.Client) (*PresignedResp, error) {
    65  	resp, err := pts.DoReader(client)
    66  	if err != nil {
    67  		return resp, err
    68  	} else if resp == nil {
    69  		return nil, nil
    70  	}
    71  	defer resp.BodyR.Close()
    72  
    73  	output, err := io.ReadAll(resp.BodyR)
    74  	if err != nil {
    75  		return &PresignedResp{StatusCode: http.StatusBadRequest}, fmt.Errorf("failed to read response body: %v", err)
    76  	}
    77  	return &PresignedResp{Body: output, Size: int64(len(output)), Header: resp.Header, StatusCode: resp.StatusCode}, nil
    78  }
    79  
    80  // DoReader sends request and returns opened body/reader if successful.
    81  // Caller is responsible for closing the reader.
    82  func (pts *PresignedReq) DoReader(client *http.Client) (*PresignedResp, error) {
    83  	region := parseSignatureV4(pts.query, pts.oreq.Header)
    84  	if region == "" {
    85  		return nil, nil
    86  	}
    87  
    88  	// S3 checks every single query param
    89  	pts.query.Del(apc.QparamProxyID)
    90  	pts.query.Del(apc.QparamUnixTime)
    91  	queryEncoded := pts.query.Encode()
    92  
    93  	// produce a new request (nreq) from the old/original one (oreq)
    94  	s3url := makeS3URL(region, pts.lom.Bck().Name, pts.lom.ObjName, queryEncoded)
    95  	nreq, err := http.NewRequest(pts.oreq.Method, s3url, pts.body)
    96  	if err != nil {
    97  		return &PresignedResp{StatusCode: http.StatusInternalServerError}, err
    98  	}
    99  	nreq.Header = pts.oreq.Header // NOTE: _not_ cloning
   100  	if nreq.Body != nil {
   101  		nreq.ContentLength = pts.oreq.ContentLength
   102  		if nreq.ContentLength == -1 {
   103  			debug.Assert(false) // FIXME: remove, or catch in debug mode
   104  			nreq.ContentLength = pts.lom.SizeBytes()
   105  		}
   106  	}
   107  
   108  	resp, err := client.Do(nreq)
   109  	if err != nil {
   110  		return &PresignedResp{StatusCode: http.StatusInternalServerError}, err
   111  	}
   112  
   113  	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
   114  		output, _ := io.ReadAll(resp.Body)
   115  		resp.Body.Close()
   116  		return &PresignedResp{StatusCode: resp.StatusCode}, fmt.Errorf("invalid status: %d, output: %s", resp.StatusCode, string(output))
   117  	}
   118  
   119  	return &PresignedResp{BodyR: resp.Body, Size: resp.ContentLength, Header: resp.Header, StatusCode: resp.StatusCode}, nil
   120  }
   121  
   122  ///////////////////
   123  // PresignedResp //
   124  ///////////////////
   125  
   126  // (compare w/ cmn/objattrs FromHeader)
   127  func (resp *PresignedResp) ObjAttrs() (oa *cmn.ObjAttrs) {
   128  	oa = &cmn.ObjAttrs{}
   129  	oa.CustomMD = make(cos.StrKVs, 3)
   130  
   131  	oa.SetCustomKey(cmn.SourceObjMD, apc.AWS)
   132  	etag := cmn.UnquoteCEV(resp.Header.Get(cos.HdrETag))
   133  	debug.Assert(etag != "")
   134  	oa.SetCustomKey(cmn.ETag, etag)
   135  	if !cmn.IsS3MultipartEtag(etag) {
   136  		oa.SetCustomKey(cmn.MD5ObjMD, etag)
   137  	}
   138  	if sz := resp.Header.Get(cos.HdrContentLength); sz != "" {
   139  		size, err := strconv.ParseInt(sz, 10, 64)
   140  		debug.AssertNoErr(err)
   141  		oa.Size = size
   142  	}
   143  	return oa
   144  }