github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ais/backend/awsmpt.go (about)

     1  //go:build aws
     2  
     3  // Package backend contains implementation of various backend providers.
     4  /*
     5   * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
     6   */
     7  package backend
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"encoding/xml"
    13  	"io"
    14  	"net/http"
    15  	"net/url"
    16  
    17  	aiss3 "github.com/NVIDIA/aistore/ais/s3"
    18  	"github.com/NVIDIA/aistore/cmn"
    19  	"github.com/NVIDIA/aistore/cmn/cos"
    20  	"github.com/NVIDIA/aistore/cmn/feat"
    21  	"github.com/NVIDIA/aistore/cmn/nlog"
    22  	"github.com/NVIDIA/aistore/core"
    23  	"github.com/aws/aws-sdk-go-v2/aws"
    24  	"github.com/aws/aws-sdk-go-v2/service/s3"
    25  	"github.com/aws/aws-sdk-go-v2/service/s3/types"
    26  )
    27  
    28  func decodeXML[T any](body []byte) (result T, _ error) {
    29  	if err := xml.Unmarshal(body, &result); err != nil {
    30  		return result, err
    31  	}
    32  	return result, nil
    33  }
    34  
    35  func StartMpt(lom *core.LOM, oreq *http.Request, oq url.Values) (id string, ecode int, _ error) {
    36  	if lom.IsFeatureSet(feat.S3PresignedRequest) && oreq != nil {
    37  		pts := aiss3.NewPresignedReq(oreq, lom, nil, oq)
    38  		resp, err := pts.Do(core.T.DataClient())
    39  		if err != nil {
    40  			return "", resp.StatusCode, err
    41  		}
    42  		if resp != nil {
    43  			result, err := decodeXML[aiss3.InitiateMptUploadResult](resp.Body)
    44  			if err != nil {
    45  				return "", http.StatusBadRequest, err
    46  			}
    47  			return result.UploadID, http.StatusOK, nil
    48  		}
    49  	}
    50  
    51  	var (
    52  		cloudBck = lom.Bck().RemoteBck()
    53  		sessConf = sessConf{bck: cloudBck}
    54  		input    = s3.CreateMultipartUploadInput{
    55  			Bucket: aws.String(cloudBck.Name),
    56  			Key:    aws.String(lom.ObjName),
    57  		}
    58  	)
    59  	svc, errN := sessConf.s3client("[start_mpt]")
    60  	if errN != nil && cmn.Rom.FastV(5, cos.SmoduleBackend) {
    61  		nlog.Warningln(errN)
    62  	}
    63  	out, err := svc.CreateMultipartUpload(context.Background(), &input)
    64  	if err == nil {
    65  		id = *out.UploadId
    66  	} else {
    67  		ecode, err = awsErrorToAISError(err, cloudBck, lom.ObjName)
    68  	}
    69  	return id, ecode, err
    70  }
    71  
    72  func PutMptPart(lom *core.LOM, r io.ReadCloser, oreq *http.Request, oq url.Values, uploadID string, size int64, partNum int32) (etag string,
    73  	ecode int, _ error) {
    74  	if lom.IsFeatureSet(feat.S3PresignedRequest) && oreq != nil {
    75  		pts := aiss3.NewPresignedReq(oreq, lom, r, oq)
    76  		resp, err := pts.Do(core.T.DataClient())
    77  		if err != nil {
    78  			return "", resp.StatusCode, err
    79  		}
    80  		if resp != nil {
    81  			ecode = resp.StatusCode
    82  			etag = cmn.UnquoteCEV(resp.Header.Get(cos.HdrETag))
    83  			return
    84  		}
    85  	}
    86  
    87  	var (
    88  		cloudBck = lom.Bck().RemoteBck()
    89  		sessConf = sessConf{bck: cloudBck}
    90  		input    = s3.UploadPartInput{
    91  			Bucket:        aws.String(cloudBck.Name),
    92  			Key:           aws.String(lom.ObjName),
    93  			Body:          r,
    94  			UploadId:      aws.String(uploadID),
    95  			PartNumber:    &partNum,
    96  			ContentLength: &size,
    97  		}
    98  	)
    99  	svc, errN := sessConf.s3client("[put_mpt_part]")
   100  	if errN != nil && cmn.Rom.FastV(5, cos.SmoduleBackend) {
   101  		nlog.Warningln(errN)
   102  	}
   103  
   104  	out, err := svc.UploadPart(context.Background(), &input)
   105  	if err != nil {
   106  		ecode, err = awsErrorToAISError(err, cloudBck, lom.ObjName)
   107  	} else {
   108  		etag = cmn.UnquoteCEV(*out.ETag)
   109  	}
   110  
   111  	return etag, ecode, err
   112  }
   113  
   114  func CompleteMpt(lom *core.LOM, oreq *http.Request, oq url.Values, uploadID string, parts *aiss3.CompleteMptUpload) (etag string,
   115  	ecode int, _ error) {
   116  	if lom.IsFeatureSet(feat.S3PresignedRequest) && oreq != nil {
   117  		body, err := xml.Marshal(parts)
   118  		if err != nil {
   119  			return "", http.StatusBadRequest, err
   120  		}
   121  		pts := aiss3.NewPresignedReq(oreq, lom, io.NopCloser(bytes.NewReader(body)), oq)
   122  		resp, err := pts.Do(core.T.DataClient())
   123  		if err != nil {
   124  			return "", resp.StatusCode, err
   125  		}
   126  		if resp != nil {
   127  			result, err := decodeXML[aiss3.CompleteMptUploadResult](resp.Body)
   128  			if err != nil {
   129  				return "", http.StatusBadRequest, err
   130  			}
   131  			etag = result.ETag
   132  			return
   133  		}
   134  	}
   135  
   136  	var (
   137  		cloudBck = lom.Bck().RemoteBck()
   138  		sessConf = sessConf{bck: cloudBck}
   139  		s3parts  types.CompletedMultipartUpload
   140  		input    = s3.CompleteMultipartUploadInput{
   141  			Bucket:   aws.String(cloudBck.Name),
   142  			Key:      aws.String(lom.ObjName),
   143  			UploadId: aws.String(uploadID),
   144  		}
   145  	)
   146  	svc, errN := sessConf.s3client("[complete_mpt]")
   147  	if errN != nil && cmn.Rom.FastV(5, cos.SmoduleBackend) {
   148  		nlog.Warningln(errN)
   149  	}
   150  
   151  	// TODO -- FIXME: reduce copying
   152  	s3parts.Parts = make([]types.CompletedPart, 0, len(parts.Parts))
   153  	for _, part := range parts.Parts {
   154  		s3parts.Parts = append(s3parts.Parts, types.CompletedPart{
   155  			ETag:       aws.String(part.ETag),
   156  			PartNumber: aws.Int32(part.PartNumber),
   157  		})
   158  	}
   159  	input.MultipartUpload = &s3parts
   160  
   161  	out, err := svc.CompleteMultipartUpload(context.Background(), &input)
   162  	if err != nil {
   163  		ecode, err = awsErrorToAISError(err, cloudBck, lom.ObjName)
   164  	} else {
   165  		etag = cmn.UnquoteCEV(*out.ETag)
   166  	}
   167  
   168  	return etag, ecode, err
   169  }
   170  
   171  func AbortMpt(lom *core.LOM, oreq *http.Request, oq url.Values, uploadID string) (ecode int, err error) {
   172  	if lom.IsFeatureSet(feat.S3PresignedRequest) && oreq != nil {
   173  		pts := aiss3.NewPresignedReq(oreq, lom, oreq.Body, oq)
   174  		resp, err := pts.Do(core.T.DataClient())
   175  		if err != nil {
   176  			return resp.StatusCode, err
   177  		}
   178  		if resp != nil {
   179  			return resp.StatusCode, nil
   180  		}
   181  	}
   182  
   183  	var (
   184  		cloudBck = lom.Bck().RemoteBck()
   185  		sessConf = sessConf{bck: cloudBck}
   186  		input    = s3.AbortMultipartUploadInput{
   187  			Bucket:   aws.String(cloudBck.Name),
   188  			Key:      aws.String(lom.ObjName),
   189  			UploadId: aws.String(uploadID),
   190  		}
   191  	)
   192  	svc, errN := sessConf.s3client("[abort_mpt]")
   193  	if errN != nil && cmn.Rom.FastV(5, cos.SmoduleBackend) {
   194  		nlog.Warningln(errN)
   195  	}
   196  	if _, err = svc.AbortMultipartUpload(context.Background(), &input); err != nil {
   197  		ecode, err = awsErrorToAISError(err, cloudBck, lom.ObjName)
   198  	}
   199  	return ecode, err
   200  }