github.com/snowflakedb/gosnowflake@v1.9.0/s3_storage_client.go (about) 1 // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "os" 13 "strings" 14 15 "github.com/aws/aws-sdk-go-v2/aws" 16 "github.com/aws/aws-sdk-go-v2/credentials" 17 "github.com/aws/aws-sdk-go-v2/feature/s3/manager" 18 "github.com/aws/aws-sdk-go-v2/service/s3" 19 "github.com/aws/smithy-go" 20 ) 21 22 const ( 23 sfcDigest = "sfc-digest" 24 amzMatdesc = "x-amz-matdesc" 25 amzKey = "x-amz-key" 26 amzIv = "x-amz-iv" 27 28 notFound = "NotFound" 29 expiredToken = "ExpiredToken" 30 errNoWsaeconnaborted = "10053" 31 ) 32 33 type snowflakeS3Client struct { 34 } 35 36 type s3Location struct { 37 bucketName string 38 s3Path string 39 } 40 41 func (util *snowflakeS3Client) createClient(info *execResponseStageInfo, useAccelerateEndpoint bool) (cloudClient, error) { 42 stageCredentials := info.Creds 43 var resolver s3.EndpointResolver 44 if info.EndPoint != "" { 45 resolver = s3.EndpointResolverFromURL("https://" + info.EndPoint) // FIPS endpoint 46 } 47 48 return s3.New(s3.Options{ 49 Region: info.Region, 50 Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider( 51 stageCredentials.AwsKeyID, 52 stageCredentials.AwsSecretKey, 53 stageCredentials.AwsToken)), 54 EndpointResolver: resolver, 55 UseAccelerate: useAccelerateEndpoint, 56 HTTPClient: &http.Client{ 57 Transport: SnowflakeTransport, 58 }, 59 }), nil 60 } 61 62 type s3HeaderAPI interface { 63 HeadObject(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) 64 } 65 66 // cloudUtil implementation 67 func (util *snowflakeS3Client) getFileHeader(meta *fileMetadata, filename string) (*fileHeader, error) { 68 headObjInput, err := util.getS3Object(meta, filename) 69 if err != nil { 70 return nil, err 71 } 72 var s3Cli s3HeaderAPI 73 s3Cli, ok := meta.client.(*s3.Client) 74 if !ok { 75 return nil, fmt.Errorf("could not parse client to s3.Client") 76 } 77 // for testing only 78 if meta.mockHeader != nil { 79 s3Cli = meta.mockHeader 80 } 81 out, err := s3Cli.HeadObject(context.Background(), headObjInput) 82 if err != nil { 83 var ae smithy.APIError 84 if errors.As(err, &ae) { 85 if ae.ErrorCode() == notFound { 86 meta.resStatus = notFoundFile 87 return nil, fmt.Errorf("could not find file") 88 } else if ae.ErrorCode() == expiredToken { 89 meta.resStatus = renewToken 90 return nil, fmt.Errorf("received expired token. renewing") 91 } 92 meta.resStatus = errStatus 93 meta.lastError = err 94 return nil, fmt.Errorf("error while retrieving header") 95 } 96 meta.resStatus = errStatus 97 meta.lastError = err 98 return nil, fmt.Errorf("unexpected error while retrieving header: %v", err) 99 } 100 101 meta.resStatus = uploaded 102 var encMeta encryptMetadata 103 if out.Metadata[amzKey] != "" { 104 encMeta = encryptMetadata{ 105 out.Metadata[amzKey], 106 out.Metadata[amzIv], 107 out.Metadata[amzMatdesc], 108 } 109 } 110 contentLength := convertContentLength(out.ContentLength) 111 return &fileHeader{ 112 out.Metadata[sfcDigest], 113 contentLength, 114 &encMeta, 115 }, nil 116 } 117 118 // SNOW-974548 remove this function after upgrading AWS SDK 119 func convertContentLength(contentLength any) int64 { 120 switch t := contentLength.(type) { 121 case int64: 122 return t 123 case *int64: 124 if t != nil { 125 return *t 126 } 127 } 128 return 0 129 } 130 131 type s3UploadAPI interface { 132 Upload(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) 133 } 134 135 // cloudUtil implementation 136 func (util *snowflakeS3Client) uploadFile( 137 dataFile string, 138 meta *fileMetadata, 139 encryptMeta *encryptMetadata, 140 maxConcurrency int, 141 multiPartThreshold int64) error { 142 s3Meta := map[string]string{ 143 httpHeaderContentType: httpHeaderValueOctetStream, 144 sfcDigest: meta.sha256Digest, 145 } 146 if encryptMeta != nil { 147 s3Meta[amzIv] = encryptMeta.iv 148 s3Meta[amzKey] = encryptMeta.key 149 s3Meta[amzMatdesc] = encryptMeta.matdesc 150 } 151 152 s3loc, err := util.extractBucketNameAndPath(meta.stageInfo.Location) 153 if err != nil { 154 return err 155 } 156 s3path := s3loc.s3Path + strings.TrimLeft(meta.dstFileName, "/") 157 158 client, ok := meta.client.(*s3.Client) 159 if !ok { 160 return &SnowflakeError{ 161 Message: "failed to cast to s3 client", 162 } 163 } 164 var uploader s3UploadAPI 165 uploader = manager.NewUploader(client, func(u *manager.Uploader) { 166 u.Concurrency = maxConcurrency 167 u.PartSize = int64Max(multiPartThreshold, manager.DefaultUploadPartSize) 168 }) 169 // for testing only 170 if meta.mockUploader != nil { 171 uploader = meta.mockUploader 172 } 173 174 if meta.srcStream != nil { 175 uploadStream := meta.srcStream 176 if meta.realSrcStream != nil { 177 uploadStream = meta.realSrcStream 178 } 179 _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ 180 Bucket: &s3loc.bucketName, 181 Key: &s3path, 182 Body: bytes.NewBuffer(uploadStream.Bytes()), 183 Metadata: s3Meta, 184 }) 185 } else { 186 var file *os.File 187 file, err = os.Open(dataFile) 188 if err != nil { 189 return err 190 } 191 _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ 192 Bucket: &s3loc.bucketName, 193 Key: &s3path, 194 Body: file, 195 Metadata: s3Meta, 196 }) 197 } 198 199 if err != nil { 200 var ae smithy.APIError 201 if errors.As(err, &ae) { 202 if ae.ErrorCode() == expiredToken { 203 meta.resStatus = renewToken 204 return err 205 } else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) { 206 meta.lastError = err 207 meta.resStatus = needRetryWithLowerConcurrency 208 return err 209 } 210 } 211 meta.lastError = err 212 meta.resStatus = needRetry 213 return err 214 } 215 meta.dstFileSize = meta.uploadSize 216 meta.resStatus = uploaded 217 return nil 218 } 219 220 type s3DownloadAPI interface { 221 Download(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) 222 } 223 224 // cloudUtil implementation 225 func (util *snowflakeS3Client) nativeDownloadFile( 226 meta *fileMetadata, 227 fullDstFileName string, 228 maxConcurrency int64) error { 229 s3Obj, _ := util.getS3Object(meta, meta.srcFileName) 230 client, ok := meta.client.(*s3.Client) 231 if !ok { 232 return &SnowflakeError{ 233 Message: "failed to cast to s3 client", 234 } 235 } 236 237 f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode) 238 if err != nil { 239 return err 240 } 241 defer f.Close() 242 var downloader s3DownloadAPI 243 downloader = manager.NewDownloader(client, func(u *manager.Downloader) { 244 u.Concurrency = int(maxConcurrency) 245 }) 246 // for testing only 247 if meta.mockDownloader != nil { 248 downloader = meta.mockDownloader 249 } 250 if _, err = downloader.Download(context.Background(), f, &s3.GetObjectInput{ 251 Bucket: s3Obj.Bucket, 252 Key: s3Obj.Key, 253 }); err != nil { 254 var ae smithy.APIError 255 if errors.As(err, &ae) { 256 if ae.ErrorCode() == expiredToken { 257 meta.resStatus = renewToken 258 return err 259 } else if strings.Contains(ae.ErrorCode(), errNoWsaeconnaborted) { 260 meta.lastError = err 261 meta.resStatus = needRetryWithLowerConcurrency 262 return err 263 } 264 meta.lastError = err 265 meta.resStatus = errStatus 266 return err 267 } 268 meta.lastError = err 269 meta.resStatus = needRetry 270 return err 271 } 272 meta.resStatus = downloaded 273 return nil 274 } 275 276 func (util *snowflakeS3Client) extractBucketNameAndPath(location string) (*s3Location, error) { 277 stageLocation, err := expandUser(location) 278 if err != nil { 279 return nil, err 280 } 281 bucketName := stageLocation 282 s3Path := "" 283 284 if idx := strings.Index(stageLocation, "/"); idx >= 0 { 285 bucketName = stageLocation[0:idx] 286 s3Path = stageLocation[idx+1:] 287 if s3Path != "" && !strings.HasSuffix(s3Path, "/") { 288 s3Path += "/" 289 } 290 } 291 return &s3Location{bucketName, s3Path}, nil 292 } 293 294 func (util *snowflakeS3Client) getS3Object(meta *fileMetadata, filename string) (*s3.HeadObjectInput, error) { 295 s3loc, err := util.extractBucketNameAndPath(meta.stageInfo.Location) 296 if err != nil { 297 return nil, err 298 } 299 s3path := s3loc.s3Path + strings.TrimLeft(filename, "/") 300 return &s3.HeadObjectInput{ 301 Bucket: &s3loc.bucketName, 302 Key: &s3path, 303 }, nil 304 }