github.com/snowflakedb/gosnowflake@v1.9.0/s3_storage_client.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"os"
    13  	"strings"
    14  
    15  	"github.com/aws/aws-sdk-go-v2/aws"
    16  	"github.com/aws/aws-sdk-go-v2/credentials"
    17  	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
    18  	"github.com/aws/aws-sdk-go-v2/service/s3"
    19  	"github.com/aws/smithy-go"
    20  )
    21  
const (
	// Keys of the user-defined S3 object metadata written by uploadFile and
	// read back by getFileHeader.
	//
	// sfcDigest holds the file's SHA-256 digest (meta.sha256Digest).
	// amzMatdesc, amzKey and amzIv hold the matdesc, key and iv fields of the
	// encryptMetadata attached to the uploaded object.
	sfcDigest  = "sfc-digest"
	amzMatdesc = "x-amz-matdesc"
	amzKey     = "x-amz-key"
	amzIv      = "x-amz-iv"

	// Values matched against smithy.APIError.ErrorCode() to classify failures.
	//
	// notFound and expiredToken are compared for equality; they map to the
	// notFoundFile and renewToken result statuses respectively.
	// errNoWsaeconnaborted is matched as a substring of the error code and
	// triggers a retry with lower concurrency.
	// NOTE(review): 10053 is presumably the Windows WSAECONNABORTED socket
	// error number, going by the constant's name — confirm.
	notFound             = "NotFound"
	expiredToken         = "ExpiredToken"
	errNoWsaeconnaborted = "10053"
)
    32  
// snowflakeS3Client groups the S3-specific stage operations defined in this
// file (client creation, header lookup, upload and download). It carries no
// state of its own; everything it needs is passed in through the stage info
// and file metadata arguments of its methods.
type snowflakeS3Client struct {
}
    35  
// s3Location is a stage location split into its bucket and key prefix, as
// produced by extractBucketNameAndPath.
type s3Location struct {
	// bucketName is the part of the stage location before the first "/".
	bucketName string
	// s3Path is the remainder of the stage location after the first "/";
	// when non-empty it always ends with "/" so file names can be appended.
	s3Path     string
}
    40  
    41  func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) {
    42  	stageCredentials := info.Creds
    43  	var resolver s3.EndpointResolver
    44  	if info.EndPoint != "" {
    45  		resolver = s3.EndpointResolverFromURL("https://" + info.EndPoint) // FIPS endpoint
    46  	}
    47  
    48  	return s3.New(s3.Options{
    49  		Region: info.Region,
    50  		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(
    51  			stageCredentials.AwsKeyID,
    52  			stageCredentials.AwsSecretKey,
    53  			stageCredentials.AwsToken)),
    54  		EndpointResolver: resolver,
    55  		UseAccelerate:    useAccelerateEndpoint,
    56  		HTTPClient: &http.Client{
    57  			Transport: SnowflakeTransport,
    58  		},
    59  	}), nil
    60  }
    61  
// s3HeaderAPI is the single S3 operation getFileHeader needs. It lets
// getFileHeader call either a real *s3.Client or the test double stored in
// fileMetadata.mockHeader.
type s3HeaderAPI interface {
	HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error)
}
    65  
    66  // cloudUtil implementation
    67  func (util *snowflakeS3Client) getFileHeader(meta *fileMetadata, filename string) (*fileHeader, error) {
    68  	headObjInput, err := util.getS3Object(meta, filename)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	var s3Cli s3HeaderAPI
    73  	s3Cli, ok := meta.client.(*s3.Client)
    74  	if !ok {
    75  		return nil, fmt.Errorf("could not parse client to s3.Client")
    76  	}
    77  	// for testing only
    78  	if meta.mockHeader != nil {
    79  		s3Cli = meta.mockHeader
    80  	}
    81  	out, err := s3Cli.HeadObject(context.Background(), headObjInput)
    82  	if err != nil {
    83  		var ae smithy.APIError
    84  		if errors.As(err, &ae) {
    85  			if ae.ErrorCode() == notFound {
    86  				meta.resStatus = notFoundFile
    87  				return nil, fmt.Errorf("could not find file")
    88  			} else if ae.ErrorCode() == expiredToken {
    89  				meta.resStatus = renewToken
    90  				return nil, fmt.Errorf("received expired token. renewing")
    91  			}
    92  			meta.resStatus = errStatus
    93  			meta.lastError = err
    94  			return nil, fmt.Errorf("error while retrieving header")
    95  		}
    96  		meta.resStatus = errStatus
    97  		meta.lastError = err
    98  		return nil, fmt.Errorf("unexpected error while retrieving header: %v", err)
    99  	}
   100  
   101  	meta.resStatus = uploaded
   102  	var encMeta encryptMetadata
   103  	if out.Metadata[amzKey] != "" {
   104  		encMeta = encryptMetadata{
   105  			out.Metadata[amzKey],
   106  			out.Metadata[amzIv],
   107  			out.Metadata[amzMatdesc],
   108  		}
   109  	}
   110  	contentLength := convertContentLength(out.ContentLength)
   111  	return &fileHeader{
   112  		out.Metadata[sfcDigest],
   113  		contentLength,
   114  		&encMeta,
   115  	}, nil
   116  }
   117  
   118  // SNOW-974548 remove this function after upgrading AWS SDK
   119  func convertContentLength(contentLength any) int64 {
   120  	switch t := contentLength.(type) {
   121  	case int64:
   122  		return t
   123  	case *int64:
   124  		if t != nil {
   125  			return *t
   126  		}
   127  	}
   128  	return 0
   129  }
   130  
// s3UploadAPI is the single operation uploadFile needs from an uploader. It
// lets uploadFile use either the uploader returned by manager.NewUploader or
// the test double stored in fileMetadata.mockUploader.
type s3UploadAPI interface {
	Upload(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error)
}
   134  
   135  // cloudUtil implementation
   136  func (util *snowflakeS3Client) uploadFile(
   137  	dataFile string,
   138  	meta *fileMetadata,
   139  	encryptMeta *encryptMetadata,
   140  	maxConcurrency int,
   141  	multiPartThreshold int64) error {
   142  	s3Meta := map[string]string{
   143  		httpHeaderContentType: httpHeaderValueOctetStream,
   144  		sfcDigest:             meta.sha256Digest,
   145  	}
   146  	if encryptMeta != nil {
   147  		s3Meta[amzIv] = encryptMeta.iv
   148  		s3Meta[amzKey] = encryptMeta.key
   149  		s3Meta[amzMatdesc] = encryptMeta.matdesc
   150  	}
   151  
   152  	s3loc, err := util.extractBucketNameAndPath(meta.stageInfo.Location)
   153  	if err != nil {
   154  		return err
   155  	}
   156  	s3path := s3loc.s3Path + strings.TrimLeft(meta.dstFileName, "/")
   157  
   158  	client, ok := meta.client.(*s3.Client)
   159  	if !ok {
   160  		return &SnowflakeError{
   161  			Message: "failed to cast to s3 client",
   162  		}
   163  	}
   164  	var uploader s3UploadAPI
   165  	uploader = manager.NewUploader(client, func(u *manager.Uploader) {
   166  		u.Concurrency = maxConcurrency
   167  		u.PartSize = int64Max(multiPartThreshold, manager.DefaultUploadPartSize)
   168  	})
   169  	// for testing only
   170  	if meta.mockUploader != nil {
   171  		uploader = meta.mockUploader
   172  	}
   173  
   174  	if meta.srcStream != nil {
   175  		uploadStream := meta.srcStream
   176  		if meta.realSrcStream != nil {
   177  			uploadStream = meta.realSrcStream
   178  		}
   179  		_, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
   180  			Bucket:   &s3loc.bucketName,
   181  			Key:      &s3path,
   182  			Body:     bytes.NewBuffer(uploadStream.Bytes()),
   183  			Metadata: s3Meta,
   184  		})
   185  	} else {
   186  		var file *os.File
   187  		file, err = os.Open(dataFile)
   188  		if err != nil {
   189  			return err
   190  		}
   191  		_, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
   192  			Bucket:   &s3loc.bucketName,
   193  			Key:      &s3path,
   194  			Body:     file,
   195  			Metadata: s3Meta,
   196  		})
   197  	}
   198  
   199  	if err != nil {
   200  		var ae smithy.APIError
   201  		if errors.As(err, &ae) {
   202  			if ae.ErrorCode() == expiredToken {
   203  				meta.resStatus = renewToken
   204  				return err
   205  			} else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) {
   206  				meta.lastError = err
   207  				meta.resStatus = needRetryWithLowerConcurrency
   208  				return err
   209  			}
   210  		}
   211  		meta.lastError = err
   212  		meta.resStatus = needRetry
   213  		return err
   214  	}
   215  	meta.dstFileSize = meta.uploadSize
   216  	meta.resStatus = uploaded
   217  	return nil
   218  }
   219  
// s3DownloadAPI is the single operation nativeDownloadFile needs from a
// downloader. It lets nativeDownloadFile use either the downloader returned
// by manager.NewDownloader or the test double stored in
// fileMetadata.mockDownloader.
type s3DownloadAPI interface {
	Download(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error)
}
   223  
   224  // cloudUtil implementation
   225  func (util *snowflakeS3Client) nativeDownloadFile(
   226  	meta *fileMetadata,
   227  	fullDstFileName string,
   228  	maxConcurrency int64) error {
   229  	s3Obj, _ := util.getS3Object(meta, meta.srcFileName)
   230  	client, ok := meta.client.(*s3.Client)
   231  	if !ok {
   232  		return &SnowflakeError{
   233  			Message: "failed to cast to s3 client",
   234  		}
   235  	}
   236  
   237  	f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
   238  	if err != nil {
   239  		return err
   240  	}
   241  	defer f.Close()
   242  	var downloader s3DownloadAPI
   243  	downloader = manager.NewDownloader(client, func(u *manager.Downloader) {
   244  		u.Concurrency = int(maxConcurrency)
   245  	})
   246  	// for testing only
   247  	if meta.mockDownloader != nil {
   248  		downloader = meta.mockDownloader
   249  	}
   250  	if _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{
   251  		Bucket: s3Obj.Bucket,
   252  		Key:    s3Obj.Key,
   253  	}); err != nil {
   254  		var ae smithy.APIError
   255  		if errors.As(err, &ae) {
   256  			if ae.ErrorCode() == expiredToken {
   257  				meta.resStatus = renewToken
   258  				return err
   259  			} else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) {
   260  				meta.lastError = err
   261  				meta.resStatus = needRetryWithLowerConcurrency
   262  				return err
   263  			}
   264  			meta.lastError = err
   265  			meta.resStatus = errStatus
   266  			return err
   267  		}
   268  		meta.lastError = err
   269  		meta.resStatus = needRetry
   270  		return err
   271  	}
   272  	meta.resStatus = downloaded
   273  	return nil
   274  }
   275  
   276  func (util *snowflakeS3Client) extractBucketNameAndPath(location string) (*s3Location, error) {
   277  	stageLocation, err := expandUser(location)
   278  	if err != nil {
   279  		return nil, err
   280  	}
   281  	bucketName := stageLocation
   282  	s3Path := ""
   283  
   284  	if idx := strings.Index(stageLocation, "/"); idx >= 0 {
   285  		bucketName = stageLocation[0:idx]
   286  		s3Path = stageLocation[idx+1:]
   287  		if s3Path != "" && !strings.HasSuffix(s3Path, "/") {
   288  			s3Path += "/"
   289  		}
   290  	}
   291  	return &s3Location{bucketName, s3Path}, nil
   292  }
   293  
   294  func (util *snowflakeS3Client) getS3Object(meta *fileMetadata, filename string) (*s3.HeadObjectInput, error) {
   295  	s3loc, err := util.extractBucketNameAndPath(meta.stageInfo.Location)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  	s3path := s3loc.s3Path + strings.TrimLeft(filename, "/")
   300  	return &s3.HeadObjectInput{
   301  		Bucket: &s3loc.bucketName,
   302  		Key:    &s3path,
   303  	}, nil
   304  }