// Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.

package gosnowflake

//lint:file-ignore U1000 Ignore all unused code

import (
	"bytes"
	"context"
	"database/sql/driver"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"net/url"
	"os"
	"path/filepath"
	"regexp"
	"runtime"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go-v2/service/s3"
	"github.com/aws/smithy-go"
	"github.com/gabriel-vasile/mimetype"
)

type (
	// cloudType identifies the stage's backing storage (S3/AZURE/GCS/LOCAL_FS).
	cloudType string
	// commandType distinguishes PUT (UPLOAD) from GET (DOWNLOAD).
	commandType string
)

const (
	// fileProtocol is the URI scheme used in PUT/GET statements for local files.
	fileProtocol = "file://"
	// dataSizeThreshold is 64 MiB; used as a size cutoff for transfers.
	dataSizeThreshold int64 = 64 * 1024 * 1024
	// injectWaitPut > 0 injects an artificial sleep between sequential uploads (testing hook).
	injectWaitPut = 0
	isWindows     = runtime.GOOS == "windows"
	mb    float64 = 1024.0 * 1024.0
)

const (
	uploadCommand   commandType = "UPLOAD"
	downloadCommand commandType = "DOWNLOAD"
	unknownCommand  commandType = "UNKNOWN"

	// putRegexp/getRegexp match PUT/GET statements, allowing leading /* ... */ comments.
	putRegexp string = `(?i)^(?:/\*.*\*/\s*)*put\s+`
	getRegexp string = `(?i)^(?:/\*.*\*/\s*)*get\s+`
)

const (
	s3Client    cloudType = "S3"
	azureClient cloudType = "AZURE"
	gcsClient   cloudType = "GCS"
	local       cloudType = "LOCAL_FS"
)

// resultStatus is the per-file outcome of an upload or download attempt.
type resultStatus int

const (
	errStatus resultStatus = iota
	uploaded
	downloaded
	skipped
	renewToken
	renewPresignedURL
	notFoundFile
	needRetry
	needRetryWithLowerConcurrency
)

// String returns the wire/display name of the status; the array order must
// match the iota order above.
func (rs resultStatus) String() string {
	return [...]string{"ERROR", "UPLOADED", "DOWNLOADED", "SKIPPED",
		"RENEW_TOKEN", "RENEW_PRESIGNED_URL", "NOT_FOUND_FILE", "NEED_RETRY",
		"NEED_RETRY_WITH_LOWER_CONCURRENCY"}[rs]
}

// isSet reports whether the status has been assigned a non-error value
// (anything strictly after errStatus).
func (rs resultStatus) isSet() bool {
	return uploaded <= rs && rs <= needRetryWithLowerConcurrency
}

// SnowflakeFileTransferOptions enables users to specify options regarding
// files transfers such as PUT/GET
type SnowflakeFileTransferOptions struct {
	showProgressBar  bool
	RaisePutGetError bool
MultiPartThreshold int64

	/* streaming PUT */
	compressSourceFromStream bool

	/* PUT */
	putCallback             *snowflakeProgressPercentage
	putAzureCallback        *snowflakeProgressPercentage
	putCallbackOutputStream *io.Writer

	/* GET */
	getCallback             *snowflakeProgressPercentage
	getAzureCallback        *snowflakeProgressPercentage
	getCallbackOutputStream *io.Writer
}

// snowflakeFileTransferAgent drives a single PUT or GET command: it parses
// the server's exec response, builds per-file metadata, and transfers the
// files through the stage's storage client.
type snowflakeFileTransferAgent struct {
	sc                          *snowflakeConn
	data                        *execResponseData
	command                     string
	commandType                 commandType
	stageLocationType           cloudType
	fileMetadata                []*fileMetadata
	encryptionMaterial          []*snowflakeFileEncryption
	stageInfo                   *execResponseStageInfo
	results                     []*fileMetadata
	sourceStream                *bytes.Buffer
	srcLocations                []string
	autoCompress                bool
	srcCompression              string
	parallel                    int64
	overwrite                   bool
	srcFiles                    []string
	localLocation               string
	srcFileToEncryptionMaterial map[string]*snowflakeFileEncryption
	useAccelerateEndpoint       bool
	presignedURLs               []string
	options                     *SnowflakeFileTransferOptions
}

// execute runs the full transfer pipeline: parse the command, build file
// metadata, resolve compression (upload only), probe S3 transfer
// acceleration, create target directories as needed, then split files into
// small/large batches and upload or download them.
func (sfa *snowflakeFileTransferAgent) execute() error {
	var err error
	if err = sfa.parseCommand(); err != nil {
		return err
	}
	if err = sfa.initFileMetadata(); err != nil {
		return err
	}

	if sfa.commandType == uploadCommand {
		if err = sfa.processFileCompressionType(); err != nil {
			return err
		}
	}

	if err = sfa.transferAccelerateConfig(); err != nil {
		return err
	}

	// GET: make sure the local target directory exists.
	if sfa.commandType == downloadCommand {
		if _, err = os.Stat(sfa.localLocation); os.IsNotExist(err) {
			if err = os.MkdirAll(sfa.localLocation, os.ModePerm); err != nil {
				return err
			}
		}
	}

	// Local stage: make sure the stage directory itself exists.
	if sfa.stageLocationType == local {
		if _, err = os.Stat(sfa.stageInfo.Location); os.IsNotExist(err) {
			if err = os.MkdirAll(sfa.stageInfo.Location, os.ModePerm); err != nil {
				return err
			}
		}
	}

	if err =
sfa.updateFileMetadataWithPresignedURL(); err != nil {
		return err
	}

	smallFileMetas := make([]*fileMetadata, 0)
	largeFileMetas := make([]*fileMetadata, 0)

	// Partition files: for remote stages, uploads above MultiPartThreshold go
	// to the "large" (sequential, multipart) path; everything else is "small"
	// and transferred with parallelism 1 per file.
	for _, meta := range sfa.fileMetadata {
		meta.overwrite = sfa.overwrite
		meta.sfa = sfa
		meta.options = sfa.options
		if sfa.stageLocationType != local {
			sizeThreshold := sfa.options.MultiPartThreshold
			meta.options.MultiPartThreshold = sizeThreshold
			if meta.srcFileSize > sizeThreshold && sfa.commandType == uploadCommand {
				meta.parallel = sfa.parallel
				largeFileMetas = append(largeFileMetas, meta)
			} else {
				meta.parallel = 1
				smallFileMetas = append(smallFileMetas, meta)
			}
		} else {
			meta.parallel = 1
			smallFileMetas = append(smallFileMetas, meta)
		}
	}

	if sfa.commandType == uploadCommand {
		if err = sfa.upload(largeFileMetas, smallFileMetas); err != nil {
			return err
		}
	} else {
		if err = sfa.download(smallFileMetas); err != nil {
			return err
		}
	}

	return nil
}

// parseCommand extracts the transfer parameters (source files, compression,
// parallelism, stage info, presigned URLs, encryption material) from the
// server's exec response and validates them.
func (sfa *snowflakeFileTransferAgent) parseCommand() error {
	var err error
	if sfa.data.Command != "" {
		sfa.commandType = commandType(sfa.data.Command)
	} else {
		sfa.commandType = unknownCommand
	}

	sfa.initEncryptionMaterial()
	if len(sfa.data.SrcLocations) == 0 {
		return (&SnowflakeError{
			Number:   ErrInvalidStageLocation,
			SQLState: sfa.data.SQLState,
			QueryID:  sfa.data.QueryID,
			Message:  "failed to parse location",
		}).exceptionTelemetry(sfa.sc)
	}
	sfa.srcLocations = sfa.data.SrcLocations

	if sfa.commandType == uploadCommand {
		if sfa.sourceStream != nil {
			sfa.srcFiles = sfa.srcLocations // streaming PUT
		} else {
			sfa.srcFiles, err = sfa.expandFilenames(sfa.srcLocations)
			if err != nil {
				return err
			}
		}
		sfa.autoCompress = sfa.data.AutoCompress
		sfa.srcCompression = strings.ToLower(sfa.data.SourceCompression)
	} else
{
		// DOWNLOAD: map each source file to its encryption material (one
		// material per file when counts match; zero materials means the
		// stage is unencrypted).
		sfa.srcFiles = sfa.srcLocations
		sfa.srcFileToEncryptionMaterial = make(map[string]*snowflakeFileEncryption)
		if len(sfa.data.SrcLocations) == len(sfa.encryptionMaterial) {
			for i, srcFile := range sfa.srcFiles {
				sfa.srcFileToEncryptionMaterial[srcFile] = sfa.encryptionMaterial[i]
			}
		} else if len(sfa.encryptionMaterial) != 0 {
			return (&SnowflakeError{
				Number:      ErrInternalNotMatchEncryptMaterial,
				SQLState:    sfa.data.SQLState,
				QueryID:     sfa.data.QueryID,
				Message:     errMsgInternalNotMatchEncryptMaterial,
				MessageArgs: []interface{}{len(sfa.data.SrcLocations), len(sfa.encryptionMaterial)},
			}).exceptionTelemetry(sfa.sc)
		}

		sfa.localLocation, err = expandUser(sfa.data.LocalLocation)
		if err != nil {
			return err
		}
		// The GET target must be an existing directory.
		if fi, err := os.Stat(sfa.localLocation); err != nil || !fi.IsDir() {
			return (&SnowflakeError{
				Number:      ErrLocalPathNotDirectory,
				SQLState:    sfa.data.SQLState,
				QueryID:     sfa.data.QueryID,
				Message:     errMsgLocalPathNotDirectory,
				MessageArgs: []interface{}{sfa.localLocation},
			}).exceptionTelemetry(sfa.sc)
		}
	}

	sfa.parallel = 1
	if sfa.data.Parallel != 0 {
		sfa.parallel = sfa.data.Parallel
	}
	sfa.overwrite = sfa.data.Overwrite
	sfa.stageLocationType = cloudType(strings.ToUpper(sfa.data.StageInfo.LocationType))
	sfa.stageInfo = &sfa.data.StageInfo
	sfa.presignedURLs = make([]string, 0)
	if len(sfa.data.PresignedURLs) != 0 {
		sfa.presignedURLs = sfa.data.PresignedURLs
	}

	if sfa.getStorageClient(sfa.stageLocationType) == nil {
		return (&SnowflakeError{
			Number:      ErrInvalidStageFs,
			SQLState:    sfa.data.SQLState,
			QueryID:     sfa.data.QueryID,
			Message:     errMsgInvalidStageFs,
			MessageArgs: []interface{}{sfa.stageLocationType},
		}).exceptionTelemetry(sfa.sc)
	}
	return nil
}

// initEncryptionMaterial collects the encryption material from the exec
// response: a single material for uploads, one per file for downloads.
// Entries with an empty QueryID are skipped.
func (sfa *snowflakeFileTransferAgent) initEncryptionMaterial() {
	sfa.encryptionMaterial = make([]*snowflakeFileEncryption, 0)
	wrapper := sfa.data.EncryptionMaterial

	if sfa.commandType == uploadCommand {
		if wrapper.QueryID != "" {
			sfa.encryptionMaterial = append(sfa.encryptionMaterial, &wrapper.snowflakeFileEncryption)
		}
	} else {
		// BUGFIX: take the address of the slice element, not of the range
		// variable. With `for _, encmat := range ...; append(..., &encmat)`
		// every appended pointer aliases the single loop variable, so all
		// entries ended up pointing at the last material iterated.
		for i := range wrapper.EncryptionMaterials {
			if wrapper.EncryptionMaterials[i].QueryID != "" {
				sfa.encryptionMaterial = append(sfa.encryptionMaterial, &wrapper.EncryptionMaterials[i])
			}
		}
	}
}

// expandFilenames normalizes upload source paths (home-dir expansion,
// absolute-path resolution, Windows /C:/ fixup) and expands glob patterns.
// Download locations are passed through unchanged.
func (sfa *snowflakeFileTransferAgent) expandFilenames(locations []string) ([]string, error) {
	canonicalLocations := make([]string, 0)
	for _, fileName := range locations {
		if sfa.commandType == uploadCommand {
			var err error
			fileName, err = expandUser(fileName)
			if err != nil {
				return []string{}, err
			}
			if !filepath.IsAbs(fileName) {
				cwd, err := getDirectory()
				if err != nil {
					return []string{}, err
				}
				fileName = filepath.Join(cwd, fileName)
			}
			if isWindows && len(fileName) > 2 && fileName[0] == '/' && fileName[2] == ':' {
				// Windows path: /C:/data/file1.txt where it starts with slash
				// followed by a drive letter and colon.
				fileName = fileName[1:]
			}
			files, err := filepath.Glob(fileName)
			if err != nil {
				return []string{}, err
			}
			canonicalLocations = append(canonicalLocations, files...)
} else {
			// DOWNLOAD: stage paths are used verbatim.
			canonicalLocations = append(canonicalLocations, fileName)
		}
	}
	return canonicalLocations, nil
}

// initFileMetadata builds one fileMetadata per source file. For uploads it
// stats each file (rejecting missing files and directories); for a streaming
// PUT it uses the in-memory buffer's length as the size.
func (sfa *snowflakeFileTransferAgent) initFileMetadata() error {
	sfa.fileMetadata = []*fileMetadata{}
	if sfa.commandType == uploadCommand {
		if len(sfa.srcFiles) == 0 {
			fileName := sfa.data.SrcLocations
			return (&SnowflakeError{
				Number:      ErrFileNotExists,
				SQLState:    sfa.data.SQLState,
				QueryID:     sfa.data.QueryID,
				Message:     errMsgFileNotExists,
				MessageArgs: []interface{}{fileName},
			}).exceptionTelemetry(sfa.sc)
		}
		if sfa.sourceStream != nil {
			// streaming PUT: exactly one "file" backed by the source buffer.
			fileName := sfa.srcFiles[0]
			srcFileSize := int64(sfa.sourceStream.Len())
			sfa.fileMetadata = append(sfa.fileMetadata, &fileMetadata{
				name:              baseName(fileName),
				srcFileName:       fileName,
				srcStream:         sfa.sourceStream,
				srcFileSize:       srcFileSize,
				stageLocationType: sfa.stageLocationType,
				stageInfo:         sfa.stageInfo,
			})
		} else {
			for i, fileName := range sfa.srcFiles {
				fi, err := os.Stat(fileName)
				if os.IsNotExist(err) {
					return (&SnowflakeError{
						Number:      ErrFileNotExists,
						SQLState:    sfa.data.SQLState,
						QueryID:     sfa.data.QueryID,
						Message:     errMsgFileNotExists,
						MessageArgs: []interface{}{fileName},
					}).exceptionTelemetry(sfa.sc)
				} else if fi.IsDir() {
					// Directories cannot be PUT; report them as missing files.
					return (&SnowflakeError{
						Number:      ErrFileNotExists,
						SQLState:    sfa.data.SQLState,
						QueryID:     sfa.data.QueryID,
						Message:     errMsgFileNotExists,
						MessageArgs: []interface{}{fileName},
					}).exceptionTelemetry(sfa.sc)
				}
				sfa.fileMetadata = append(sfa.fileMetadata, &fileMetadata{
					name:              baseName(fileName),
					srcFileName:       fileName,
					srcFileSize:       fi.Size(),
					stageLocationType: sfa.stageLocationType,
					stageInfo:         sfa.stageInfo,
				})
				if len(sfa.encryptionMaterial) > 0 {
					sfa.fileMetadata[i].encryptionMaterial = sfa.encryptionMaterial[0]
				}
			}
		}

		if len(sfa.encryptionMaterial) > 0 {
for _, meta := range sfa.fileMetadata {
				// Uploads use a single encryption material for every file.
				meta.encryptionMaterial = sfa.encryptionMaterial[0]
			}
		}
	} else if sfa.commandType == downloadCommand {
		for _, fileName := range sfa.srcFiles {
			if len(fileName) > 0 {
				// Strip the leading stage component: the destination file name
				// is everything after the first '/'.
				firstPathSep := strings.Index(fileName, "/")
				dstFileName := fileName
				if firstPathSep >= 0 {
					dstFileName = fileName[firstPathSep+1:]
				}
				sfa.fileMetadata = append(sfa.fileMetadata, &fileMetadata{
					name:              baseName(fileName),
					srcFileName:       fileName,
					dstFileName:       dstFileName,
					stageLocationType: sfa.stageLocationType,
					stageInfo:         sfa.stageInfo,
					localLocation:     sfa.localLocation,
				})
			}
		}
		// TODO is this necessary?
		for _, meta := range sfa.fileMetadata {
			fileName := meta.srcFileName
			if val, ok := sfa.srcFileToEncryptionMaterial[fileName]; ok {
				meta.encryptionMaterial = val
			}
		}
	}

	return nil
}

// processFileCompressionType resolves each upload's source/destination
// compression: either auto-detected from the file extension (falling back to
// MIME sniffing), or taken from the user-specified SOURCE_COMPRESSION.
// Unsupported compression types are rejected.
func (sfa *snowflakeFileTransferAgent) processFileCompressionType() error {
	var userSpecifiedSourceCompression *compressionType
	var autoDetect bool
	if sfa.srcCompression == "auto_detect" {
		autoDetect = true
	} else if sfa.srcCompression == "none" {
		autoDetect = false
	} else {
		userSpecifiedSourceCompression = lookupByMimeSubType(sfa.srcCompression)
		if userSpecifiedSourceCompression == nil || !userSpecifiedSourceCompression.isSupported {
			return (&SnowflakeError{
				Number:      ErrCompressionNotSupported,
				SQLState:    sfa.data.SQLState,
				QueryID:     sfa.data.QueryID,
				Message:     errMsgFeatureNotSupported,
				MessageArgs: []interface{}{userSpecifiedSourceCompression},
			}).exceptionTelemetry(sfa.sc)
		}
		autoDetect = false
	}

	gzipCompression := compressionTypes["GZIP"]
	for _, meta := range sfa.fileMetadata {
		fileName := meta.srcFileName
		var currentFileCompressionType *compressionType
		if autoDetect {
			// First try the file extension; fall back to content sniffing below.
			currentFileCompressionType = lookupByExtension(filepath.Ext(fileName))
			if
currentFileCompressionType == nil {
				// Extension unknown: sniff the MIME type from content.
				var mtype *mimetype.MIME
				var err error
				if meta.srcStream != nil {
					r := getReaderFromBuffer(&meta.srcStream)
					mtype, err = mimetype.DetectReader(r)
					if err != nil {
						return err
					}
					io.ReadAll(r) // flush out tee buffer
				} else {
					mtype, err = mimetype.DetectFile(fileName)
					if err != nil {
						return err
					}
				}
				currentFileCompressionType = lookupByExtension(mtype.Extension())
			}

			if currentFileCompressionType != nil && !currentFileCompressionType.isSupported {
				return (&SnowflakeError{
					Number:      ErrCompressionNotSupported,
					SQLState:    sfa.data.SQLState,
					QueryID:     sfa.data.QueryID,
					Message:     errMsgFeatureNotSupported,
					MessageArgs: []interface{}{userSpecifiedSourceCompression},
				}).exceptionTelemetry(sfa.sc)
			}
		} else {
			currentFileCompressionType = userSpecifiedSourceCompression
		}

		if currentFileCompressionType != nil {
			// Already compressed: keep as-is, never recompress.
			meta.srcCompressionType = currentFileCompressionType
			if currentFileCompressionType.isSupported {
				meta.dstCompressionType = currentFileCompressionType
				meta.requireCompress = false
				meta.dstFileName = meta.name
			} else {
				return (&SnowflakeError{
					Number:      ErrCompressionNotSupported,
					SQLState:    sfa.data.SQLState,
					QueryID:     sfa.data.QueryID,
					Message:     errMsgFeatureNotSupported,
					MessageArgs: []interface{}{userSpecifiedSourceCompression},
				}).exceptionTelemetry(sfa.sc)
			}
		} else {
			// Uncompressed source: gzip it when AUTO_COMPRESS is on.
			meta.requireCompress = sfa.autoCompress
			meta.srcCompressionType = nil
			if sfa.autoCompress {
				dstFileName := meta.name + compressionTypes["GZIP"].fileExtension
				meta.dstFileName = dstFileName
				meta.dstCompressionType = gzipCompression
			} else {
				meta.dstFileName = meta.name
				meta.dstCompressionType = nil
			}
		}
	}
	return nil
}

// updateFileMetadataWithPresignedURL attaches presigned URLs to each file's
// metadata so transfers can go straight to cloud storage.
func (sfa *snowflakeFileTransferAgent) updateFileMetadataWithPresignedURL() error {
	// presigned URL only applies to
// GCS
	if sfa.stageLocationType == gcsClient {
		if sfa.commandType == uploadCommand {
			// Re-issue the PUT per file so the server returns stage info (and
			// a presigned URL) scoped to that single file.
			filePathToBeReplaced := sfa.getLocalFilePathFromCommand(sfa.command)
			for _, meta := range sfa.fileMetadata {
				// NOTE(review): TrimRight treats its second argument as a set
				// of characters, not a suffix — strings.TrimSuffix may be what
				// was intended here; confirm before changing.
				filePathToBeReplacedWith := strings.TrimRight(filePathToBeReplaced, meta.dstFileName) + meta.dstFileName
				commandWithSingleFile := strings.ReplaceAll(sfa.command, filePathToBeReplaced, filePathToBeReplacedWith)
				req := execRequest{
					SQLText: commandWithSingleFile,
				}
				headers := getHeaders()
				headers[httpHeaderAccept] = headerContentTypeApplicationJSON
				jsonBody, err := json.Marshal(req)
				if err != nil {
					return err
				}
				data, err := sfa.sc.rest.FuncPostQuery(
					sfa.sc.ctx,
					sfa.sc.rest,
					&url.Values{},
					headers,
					jsonBody,
					sfa.sc.rest.RequestTimeout,
					getOrGenerateRequestIDFromContext(sfa.sc.ctx),
					sfa.sc.cfg)
				if err != nil {
					return err
				}

				if data.Data.StageInfo != (execResponseStageInfo{}) {
					meta.stageInfo = &data.Data.StageInfo
					meta.presignedURL = nil
					if meta.stageInfo.PresignedURL != "" {
						meta.presignedURL, err = url.Parse(meta.stageInfo.PresignedURL)
						if err != nil {
							return err
						}
					}
				}
			}
		} else if sfa.commandType == downloadCommand {
			// Downloads: presigned URLs (if any) arrived with the original
			// response, one per file, in order.
			for i, meta := range sfa.fileMetadata {
				if len(sfa.presignedURLs) > 0 {
					var err error
					meta.presignedURL, err = url.Parse(sfa.presignedURLs[i])
					if err != nil {
						return err
					}
				} else {
					meta.presignedURL = nil
				}
			}
		} else {
			return (&SnowflakeError{
				Number:      ErrCommandNotRecognized,
				SQLState:    sfa.data.SQLState,
				QueryID:     sfa.data.QueryID,
				Message:     errMsgCommandNotRecognized,
				MessageArgs: []interface{}{sfa.commandType},
			}).exceptionTelemetry(sfa.sc)
		}
	}
	return nil
}

// transferAccelerateConfig probes the S3 bucket's Transfer Acceleration
// status and records whether the accelerate endpoint should be used.
func (sfa *snowflakeFileTransferAgent) transferAccelerateConfig() error {
	if sfa.stageLocationType == s3Client {
		s3Util := new(snowflakeS3Client)
s3Loc, err := s3Util.extractBucketNameAndPath(sfa.stageInfo.Location)
		if err != nil {
			return err
		}
		s3Cli, err := s3Util.createClient(sfa.stageInfo, false)
		if err != nil {
			return err
		}
		client, ok := s3Cli.(*s3.Client)
		if !ok {
			return (&SnowflakeError{
				Number:   ErrFailedToConvertToS3Client,
				SQLState: sfa.data.SQLState,
				QueryID:  sfa.data.QueryID,
				Message:  errMsgFailedToConvertToS3Client,
			}).exceptionTelemetry(sfa.sc)
		}
		ret, err := client.GetBucketAccelerateConfiguration(context.Background(), &s3.GetBucketAccelerateConfigurationInput{
			Bucket: &s3Loc.bucketName,
		})
		sfa.useAccelerateEndpoint = ret != nil && ret.Status == "Enabled"
		if err != nil {
			// Treat "can't ask" as "not accelerated" rather than failing the
			// whole transfer for these specific error codes.
			var ae smithy.APIError
			if errors.As(err, &ae) {
				if ae.ErrorCode() == "AccessDenied" {
					return nil
				} else if ae.ErrorCode() == "MethodNotAllowed" {
					return nil
				} else if strings.EqualFold(ae.ErrorCode(), "UnsupportedArgument") {
					// In AWS China and US Gov partitions, Transfer Acceleration is not supported
					// https://docs.amazonaws.cn/en_us/aws/latest/userguide/s3.html#feature-diff
					// https://docs.aws.amazon.com/govcloud-us/latest/UserGuide/govcloud-s3.html
					return nil
				}
			}
			return err
		}
	}
	return nil
}

// getLocalFilePathFromCommand extracts the local file path that follows
// "file://" in a PUT statement, handling both quoted and unquoted paths.
// Returns "" when the command is not a PUT or carries no file:// URI.
func (sfa *snowflakeFileTransferAgent) getLocalFilePathFromCommand(command string) string {
	if len(command) == 0 || !strings.Contains(command, fileProtocol) {
		return ""
	}
	if !regexp.MustCompile(putRegexp).Match([]byte(command)) {
		return ""
	}

	filePathBeginIdx := strings.Index(command, fileProtocol)
	// A quote immediately before file:// means the path runs to the closing quote.
	isFilePathQuoted := command[filePathBeginIdx-1] == '\''
	filePathBeginIdx += len(fileProtocol)
	var filePathEndIdx int
	filePath := ""

	if isFilePathQuoted {
		filePathEndIdx = filePathBeginIdx + strings.Index(command[filePathBeginIdx:], "'")
		if filePathEndIdx > filePathBeginIdx {
			filePath =
command[filePathBeginIdx:filePathEndIdx]
		}
	} else {
		// Unquoted: the path ends at the nearest space, newline, or semicolon.
		indexList := make([]int, 0)
		delims := []rune{' ', '\n', ';'}
		for _, delim := range delims {
			index := strings.Index(command[filePathBeginIdx:], string(delim))
			if index != -1 {
				indexList = append(indexList, index)
			}
		}
		filePathEndIdx = -1
		if getMin(indexList) != -1 {
			filePathEndIdx = filePathBeginIdx + getMin(indexList)
		}
		if filePathEndIdx > filePathBeginIdx {
			filePath = command[filePathBeginIdx:filePathEndIdx]
		} else {
			// No delimiter after the path: it runs to the end of the command.
			filePath = command[filePathBeginIdx:]
		}
	}
	return filePath
}

// upload creates one storage client shared by all files, then uploads small
// files in parallel batches and large files sequentially (multipart).
func (sfa *snowflakeFileTransferAgent) upload(
	largeFileMetadata []*fileMetadata,
	smallFileMetadata []*fileMetadata) error {
	client, err := sfa.getStorageClient(sfa.stageLocationType).
		createClient(sfa.stageInfo, sfa.useAccelerateEndpoint)
	if err != nil {
		return err
	}
	for _, meta := range smallFileMetadata {
		meta.client = client
	}
	for _, meta := range largeFileMetadata {
		meta.client = client
	}

	if len(smallFileMetadata) > 0 {
		logger.Infof("uploading %v small files", len(smallFileMetadata))
		if err = sfa.uploadFilesParallel(smallFileMetadata); err != nil {
			return err
		}
	}
	if len(largeFileMetadata) > 0 {
		logger.Infof("uploading %v large files", len(largeFileMetadata))
		if err = sfa.uploadFilesSequential(largeFileMetadata); err != nil {
			return err
		}
	}
	return nil
}

// download creates one storage client shared by all files and downloads them
// in parallel batches.
func (sfa *snowflakeFileTransferAgent) download(
	fileMetadata []*fileMetadata) error {
	client, err := sfa.getStorageClient(sfa.stageLocationType).
createClient(sfa.stageInfo, sfa.useAccelerateEndpoint)
	if err != nil {
		return err
	}
	for _, meta := range fileMetadata {
		meta.client = client
	}

	logger.WithContext(sfa.sc.ctx).Infof("downloading %v files", len(fileMetadata))
	if err = sfa.downloadFilesParallel(fileMetadata); err != nil {
		return err
	}
	return nil
}

// uploadFilesParallel uploads files in batches of sfa.parallel goroutines.
// Files that come back with RENEW_TOKEN / RENEW_PRESIGNED_URL are retried in
// place after refreshing the client or the presigned URLs.
func (sfa *snowflakeFileTransferAgent) uploadFilesParallel(fileMetas []*fileMetadata) error {
	idx := 0
	fileMetaLen := len(fileMetas)
	var err error
	for idx < fileMetaLen {
		endOfIdx := intMin(fileMetaLen, idx+int(sfa.parallel))
		targetMeta := fileMetas[idx:endOfIdx]
		for {
			var wg sync.WaitGroup
			results := make([]*fileMetadata, len(targetMeta))
			errors := make([]error, len(targetMeta))
			for i, meta := range targetMeta {
				wg.Add(1)
				go func(k int, m *fileMetadata) {
					defer wg.Done()
					results[k], errors[k] = sfa.uploadOneFile(m)
				}(i, meta)
			}
			wg.Wait()

			// append errors with no result associated to separate array
			var errorMessages []string
			for i, result := range results {
				if result == nil {
					if errors[i] == nil {
						errorMessages = append(errorMessages, "unknown error")
					} else {
						errorMessages = append(errorMessages, errors[i].Error())
					}
				}
			}
			if errorMessages != nil {
				// sort the error messages to be more deterministic as the goroutines may finish in different order each time
				sort.Strings(errorMessages)
				return fmt.Errorf("errors during file upload:\n%v", strings.Join(errorMessages, "\n"))
			}

			// Split outcomes: retryable statuses go around the inner loop,
			// everything else is final.
			retryMeta := make([]*fileMetadata, 0)
			for i, result := range results {
				result.errorDetails = errors[i]
				if result.resStatus == renewToken || result.resStatus == renewPresignedURL {
					retryMeta = append(retryMeta, result)
				} else {
					sfa.results = append(sfa.results, result)
				}
			}
			if len(retryMeta) == 0 {
				break
			}

			needRenewToken := false
for _, result := range retryMeta {
				if result.resStatus == renewToken {
					needRenewToken = true
				}
			}

			if needRenewToken {
				client, err := sfa.renewExpiredClient()
				if err != nil {
					return err
				}
				for _, result := range retryMeta {
					result.client = client
				}
				// Also point the not-yet-started files at the fresh client.
				if endOfIdx < fileMetaLen {
					for i := idx + int(sfa.parallel); i < fileMetaLen; i++ {
						fileMetas[i].client = client
					}
				}
			}

			for _, result := range retryMeta {
				if result.resStatus == renewPresignedURL {
					// NOTE(review): returned error is ignored here; a failed
					// refresh surfaces on the next retry instead.
					sfa.updateFileMetadataWithPresignedURL()
					break
				}
			}
			targetMeta = retryMeta
		}
		if endOfIdx == fileMetaLen {
			break
		}
		idx += int(sfa.parallel)
	}
	return err
}

// uploadFilesSequential uploads files one at a time, refreshing the client
// or the presigned URLs and retrying the same file on RENEW_* statuses.
func (sfa *snowflakeFileTransferAgent) uploadFilesSequential(fileMetas []*fileMetadata) error {
	idx := 0
	fileMetaLen := len(fileMetas)
	for idx < fileMetaLen {
		res, err := sfa.uploadOneFile(fileMetas[idx])
		if err != nil {
			return err
		}

		if res.resStatus == renewToken {
			client, err := sfa.renewExpiredClient()
			if err != nil {
				return err
			}
			for i := idx; i < fileMetaLen; i++ {
				fileMetas[i].client = client
			}
			continue // retry the same file with the fresh client
		} else if res.resStatus == renewPresignedURL {
			// NOTE(review): returned error is ignored here.
			sfa.updateFileMetadataWithPresignedURL()
			continue
		}

		sfa.results = append(sfa.results, res)
		idx++
		if injectWaitPut > 0 {
			time.Sleep(injectWaitPut) // testing hook
		}
	}
	return nil
}

// uploadOneFile prepares a single file (optional gzip compression, SHA-256
// digest and size) in a temp dir and hands it to the storage client.
func (sfa *snowflakeFileTransferAgent) uploadOneFile(meta *fileMetadata) (*fileMetadata, error) {
	meta.realSrcFileName = meta.srcFileName
	tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "")
	if err != nil {
		return nil, err
	}
	meta.tmpDir = tmpDir
	defer os.RemoveAll(tmpDir) // cleanup

	fileUtil := new(snowflakeFileUtil)
	if meta.requireCompress {
		if meta.srcStream != nil {
			meta.realSrcStream, _, err =
fileUtil.compressFileWithGzipFromStream(&meta.srcStream)
		} else {
			meta.realSrcFileName, _, err = fileUtil.compressFileWithGzip(meta.srcFileName, tmpDir)
		}
		if err != nil {
			return nil, err
		}
	}

	// Digest/size come from the compressed stream when compression happened,
	// otherwise from the raw stream or file.
	if meta.srcStream != nil {
		if meta.realSrcStream != nil {
			meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.realSrcStream)
		} else {
			meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForStream(&meta.srcStream)
		}
	} else {
		meta.sha256Digest, meta.uploadSize, err = fileUtil.getDigestAndSizeForFile(meta.realSrcFileName)
	}
	if err != nil {
		return meta, err
	}

	client := sfa.getStorageClient(sfa.stageLocationType)
	if err = client.uploadOneFileWithRetry(meta); err != nil {
		return meta, err
	}
	return meta, nil
}

// downloadFilesParallel downloads files in batches of sfa.parallel
// goroutines, retrying files whose status is RENEW_TOKEN or
// RENEW_PRESIGNED_URL after refreshing the client / URLs.
func (sfa *snowflakeFileTransferAgent) downloadFilesParallel(fileMetas []*fileMetadata) error {
	idx := 0
	fileMetaLen := len(fileMetas)
	var err error
	for idx < fileMetaLen {
		endOfIdx := intMin(fileMetaLen, idx+int(sfa.parallel))
		targetMeta := fileMetas[idx:endOfIdx]
		for {
			var wg sync.WaitGroup
			results := make([]*fileMetadata, len(targetMeta))
			errors := make([]error, len(targetMeta))
			for i, meta := range targetMeta {
				wg.Add(1)
				go func(k int, m *fileMetadata) {
					defer wg.Done()
					results[k], errors[k] = sfa.downloadOneFile(m)
				}(i, meta)
			}
			wg.Wait()

			retryMeta := make([]*fileMetadata, 0)
			for i, result := range results {
				result.errorDetails = errors[i]
				if result.resStatus == renewToken || result.resStatus == renewPresignedURL {
					retryMeta = append(retryMeta, result)
				} else {
					sfa.results = append(sfa.results, result)
				}
			}
			if len(retryMeta) == 0 {
				break
			}
			logger.WithContext(sfa.sc.ctx).Infof("%v retries found", len(retryMeta))

			needRenewToken := false
			for _, result := range retryMeta {
				if
result.resStatus == renewToken {
					needRenewToken = true
				}
				logger.WithContext(sfa.sc.ctx).Infof(
					"retrying download file %v with status %v",
					result.name, result.resStatus)
			}

			if needRenewToken {
				client, err := sfa.renewExpiredClient()
				if err != nil {
					return err
				}
				for _, result := range retryMeta {
					result.client = client
				}
				// Also point the not-yet-started files at the fresh client.
				if endOfIdx < fileMetaLen {
					for i := idx + int(sfa.parallel); i < fileMetaLen; i++ {
						fileMetas[i].client = client
					}
				}
			}

			for _, result := range retryMeta {
				if result.resStatus == renewPresignedURL {
					// NOTE(review): returned error is ignored here; a failed
					// refresh surfaces on the next retry instead.
					sfa.updateFileMetadataWithPresignedURL()
					break
				}
			}
			targetMeta = retryMeta
		}
		if endOfIdx == fileMetaLen {
			break
		}
		idx += int(sfa.parallel)
	}
	return err
}

// downloadOneFile downloads a single file into a temp dir via the storage
// client. On failure it marks the metadata (size -1, error status if none
// set) and returns both the annotated metadata and the error.
func (sfa *snowflakeFileTransferAgent) downloadOneFile(meta *fileMetadata) (*fileMetadata, error) {
	tmpDir, err := os.MkdirTemp(sfa.sc.cfg.TmpDirPath, "")
	if err != nil {
		return nil, err
	}
	meta.tmpDir = tmpDir
	defer os.RemoveAll(tmpDir) // cleanup
	client := sfa.getStorageClient(sfa.stageLocationType)
	if err = client.downloadOneFile(meta); err != nil {
		meta.dstFileSize = -1
		if !meta.resStatus.isSet() {
			meta.resStatus = errStatus
		}
		// BUGFIX: use a constant format string. The previous
		// fmt.Errorf(err.Error() + ", file=" + ...) treated the underlying
		// error text as a format string, so any '%' in it would be
		// misinterpreted as a verb (go vet printf violation).
		meta.errorDetails = fmt.Errorf("%v, file=%v", err, meta.dstFileName)
		return meta, err
	}
	return meta, nil
}

// getStorageClient maps the stage location type to its storage utility;
// returns nil for unknown types.
func (sfa *snowflakeFileTransferAgent) getStorageClient(stageLocationType cloudType) storageUtil {
	if stageLocationType == local {
		return &localUtil{}
	} else if stageLocationType == s3Client || stageLocationType == azureClient || stageLocationType == gcsClient {
		return &remoteStorageUtil{}
	}
	return nil
}

// renewExpiredClient re-runs the original command to obtain fresh stage
// credentials and builds a new storage client from them.
func (sfa *snowflakeFileTransferAgent) renewExpiredClient() (cloudClient, error) {
	data, err := sfa.sc.exec(
		sfa.sc.ctx,
		sfa.command,
		false,
		false,
		false,
		[]driver.NamedValue{})
if err != nil {
		return nil, err
	}
	storageClient := sfa.getStorageClient(sfa.stageLocationType)
	return storageClient.createClient(&data.Data.StageInfo, sfa.useAccelerateEndpoint)
}

// result converts the accumulated per-file outcomes into an execResponse
// shaped like a regular query result (one row per file), or raises the first
// per-file error when RaisePutGetError is set.
func (sfa *snowflakeFileTransferAgent) result() (*execResponse, error) {
	// inherit old response data
	data := sfa.data
	rowset := make([]fileTransferResultType, 0)
	if sfa.commandType == uploadCommand {
		if len(sfa.results) > 0 {
			for _, meta := range sfa.results {
				// Missing compression types are rendered as "NONE".
				var srcCompressionType, dstCompressionType *compressionType
				if meta.srcCompressionType != nil {
					srcCompressionType = meta.srcCompressionType
				} else {
					srcCompressionType = &compressionType{
						name: "NONE",
					}
				}
				if meta.dstCompressionType != nil {
					dstCompressionType = meta.dstCompressionType
				} else {
					dstCompressionType = &compressionType{
						name: "NONE",
					}
				}
				errorDetails := meta.errorDetails
				srcFileSize := meta.srcFileSize
				dstFileSize := meta.dstFileSize
				if sfa.options.RaisePutGetError && errorDetails != nil {
					return nil, (&SnowflakeError{
						Number:   ErrFailedToUploadToStage,
						SQLState: sfa.data.SQLState,
						QueryID:  sfa.data.QueryID,
						Message:  errorDetails.Error(),
					}).exceptionTelemetry(sfa.sc)
				}
				rowset = append(rowset, fileTransferResultType{
					meta.name,
					meta.srcFileName,
					meta.dstFileName,
					srcFileSize,
					dstFileSize,
					srcCompressionType,
					dstCompressionType,
					meta.resStatus,
					meta.errorDetails,
				})
			}
			// Deterministic row order regardless of goroutine completion order.
			sort.Slice(rowset, func(i, j int) bool {
				return rowset[i].srcFileName < rowset[j].srcFileName
			})
			ccrs := make([][]*string, 0, len(rowset))
			for _, rs := range rowset {
				srcFileSize := fmt.Sprintf("%v", rs.srcFileSize)
				dstFileSize := fmt.Sprintf("%v", rs.dstFileSize)
				resStatus := rs.resStatus.String()
				errorStr := ""
				if rs.errorDetails != nil {
					errorStr =
rs.errorDetails.Error()
				}
				// All cells are strings, matching the server's PUT result schema.
				ccrs = append(ccrs, []*string{
					&rs.srcFileName,
					&rs.dstFileName,
					&srcFileSize,
					&dstFileSize,
					&rs.srcCompressionType.name,
					&rs.dstCompressionType.name,
					&resStatus,
					&errorStr,
				})
			}
			data.RowSet = ccrs
			cc := make([]chunkRowType, len(ccrs))
			populateJSONRowSet(cc, ccrs)
			data.QueryResultFormat = "json"
			rt := []execResponseRowType{
				{Name: "source", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "target", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "source_size", ByteLength: 64, Length: 64, Type: "FIXED", Scale: 0, Nullable: false},
				{Name: "target_size", ByteLength: 64, Length: 64, Type: "FIXED", Scale: 0, Nullable: false},
				{Name: "source_compression", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "target_compression", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "status", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
				{Name: "message", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false},
			}
			data.RowType = rt
			return &execResponse{Data: *data, Success: true}, nil
		}
	} else { // DOWNLOAD
		if len(sfa.results) > 0 {
			for _, meta := range sfa.results {
				dstFileSize := meta.dstFileSize
				errorDetails := meta.errorDetails
				if sfa.options.RaisePutGetError && errorDetails != nil {
					return nil, (&SnowflakeError{
						Number:   ErrFailedToDownloadFromStage,
						SQLState: sfa.data.SQLState,
						QueryID:  sfa.data.QueryID,
						Message:  errorDetails.Error(),
					}).exceptionTelemetry(sfa.sc)
				}

				rowset = append(rowset, fileTransferResultType{
					"", "", meta.dstFileName, 0, dstFileSize,
					nil, nil, meta.resStatus, meta.errorDetails,
				})
			}
			// Deterministic row order regardless of goroutine completion order.
			sort.Slice(rowset,
func(i, j int) bool { 1105 return rowset[i].srcFileName < rowset[j].srcFileName 1106 }) 1107 ccrs := make([][]*string, 0, len(rowset)) 1108 for _, rs := range rowset { 1109 dstFileSize := fmt.Sprintf("%v", rs.dstFileSize) 1110 resStatus := rs.resStatus.String() 1111 errorStr := "" 1112 if rs.errorDetails != nil { 1113 errorStr = rs.errorDetails.Error() 1114 } 1115 ccrs = append(ccrs, []*string{ 1116 &rs.dstFileName, 1117 &dstFileSize, 1118 &resStatus, 1119 &errorStr, 1120 }) 1121 } 1122 data.RowSet = ccrs 1123 cc := make([]chunkRowType, len(ccrs)) 1124 populateJSONRowSet(cc, ccrs) 1125 data.QueryResultFormat = "json" 1126 rt := []execResponseRowType{ 1127 {Name: "file", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false}, 1128 {Name: "size", ByteLength: 64, Length: 64, Type: "FIXED", Scale: 0, Nullable: false}, 1129 {Name: "status", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false}, 1130 {Name: "message", ByteLength: 10000, Length: 10000, Type: "TEXT", Scale: 0, Nullable: false}, 1131 } 1132 data.RowType = rt 1133 return &execResponse{Data: *data, Success: true}, nil 1134 } 1135 } 1136 return nil, (&SnowflakeError{ 1137 Number: ErrNotImplemented, 1138 SQLState: sfa.data.SQLState, 1139 QueryID: sfa.data.QueryID, 1140 Message: errMsgNotImplemented, 1141 }).exceptionTelemetry(sfa.sc) 1142 } 1143 1144 func isFileTransfer(query string) bool { 1145 putRe := regexp.MustCompile(putRegexp) 1146 getRe := regexp.MustCompile(getRegexp) 1147 return putRe.Match([]byte(query)) || getRe.Match([]byte(query)) 1148 } 1149 1150 type snowflakeProgressPercentage struct { 1151 filename string 1152 fileSize float64 1153 outputStream *io.Writer 1154 showProgressBar bool 1155 seenSoFar int64 1156 done bool 1157 startTime time.Time 1158 } 1159 1160 func (spp *snowflakeProgressPercentage) call(bytesAmount int64) { 1161 if spp.outputStream != nil { 1162 spp.seenSoFar += bytesAmount 1163 percentage := spp.percent(spp.seenSoFar, spp.fileSize) 1164 
if !spp.done { 1165 spp.done = spp.updateProgress(spp.filename, spp.startTime, spp.fileSize, percentage, spp.outputStream, spp.showProgressBar) 1166 } 1167 } 1168 } 1169 1170 func (spp *snowflakeProgressPercentage) percent(seenSoFar int64, size float64) float64 { 1171 if float64(seenSoFar) >= size || size <= 0 { 1172 return 1.0 1173 } 1174 return float64(seenSoFar) / size 1175 } 1176 1177 func (spp *snowflakeProgressPercentage) updateProgress(filename string, startTime time.Time, totalSize float64, progress float64, outputStream *io.Writer, showProgressBar bool) bool { 1178 barLength := 10 1179 totalSize /= mb 1180 status := "" 1181 elapsedTime := time.Since(startTime) 1182 1183 var throughput float64 1184 if elapsedTime != 0.0 { 1185 throughput = totalSize / elapsedTime.Seconds() 1186 } 1187 1188 if progress < 0 { 1189 progress = 0 1190 status = "Halt...\r\n" 1191 } 1192 if progress >= 1 { 1193 status = fmt.Sprintf("Done (%.3fs, %.2fMB/s)", elapsedTime.Seconds(), throughput) 1194 } 1195 if status == "" && showProgressBar { 1196 status = fmt.Sprintf("(%.3fsm %.2fMB/s)", elapsedTime.Seconds(), throughput) 1197 } 1198 if status != "" { 1199 block := int(math.Round(float64(barLength) * progress)) 1200 text := fmt.Sprintf("\r%v(%.2fMB): [%v] %.2f%% %v ", filename, totalSize, strings.Repeat("#", block)+strings.Repeat("-", barLength-block), progress*100, status) 1201 (*outputStream).Write([]byte(text)) 1202 } 1203 return progress == 1.0 1204 }