github.com/snowflakedb/gosnowflake@v1.9.0/s3_storage_client_test.go (about)

     1  // Copyright (c) 2021-2023 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"os"
    12  	"path"
    13  	"strconv"
    14  	"testing"
    15  
    16  	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
    17  	"github.com/aws/aws-sdk-go-v2/service/s3"
    18  	"github.com/aws/smithy-go"
    19  )
    20  
    21  type tcBucketPath struct {
    22  	in     string
    23  	bucket string
    24  	path   string
    25  }
    26  
    27  func TestExtractBucketNameAndPath(t *testing.T) {
    28  	s3util := new(snowflakeS3Client)
    29  	testcases := []tcBucketPath{
    30  		{"sfc-eng-regression/test_sub_dir/", "sfc-eng-regression", "test_sub_dir/"},
    31  		{"sfc-eng-regression/dir/test_stg/test_sub_dir/", "sfc-eng-regression", "dir/test_stg/test_sub_dir/"},
    32  		{"sfc-eng-regression/", "sfc-eng-regression", ""},
    33  		{"sfc-eng-regression//", "sfc-eng-regression", "/"},
    34  		{"sfc-eng-regression///", "sfc-eng-regression", "//"},
    35  	}
    36  	for _, test := range testcases {
    37  		t.Run(test.in, func(t *testing.T) {
    38  			s3Loc, err := s3util.extractBucketNameAndPath(test.in)
    39  			if err != nil {
    40  				t.Error(err)
    41  			}
    42  			if s3Loc.bucketName != test.bucket {
    43  				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.bucket, s3Loc.bucketName)
    44  			}
    45  			if s3Loc.s3Path != test.path {
    46  				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.path, s3Loc.s3Path)
    47  			}
    48  		})
    49  	}
    50  }
    51  
    52  type mockUploadObjectAPI func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error)
    53  
    54  func (m mockUploadObjectAPI) Upload(
    55  	ctx context.Context,
    56  	params *s3.PutObjectInput,
    57  	optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
    58  	return m(ctx, params, optFns...)
    59  }
    60  
    61  func TestUploadOneFileToS3WSAEConnAborted(t *testing.T) {
    62  	info := execResponseStageInfo{
    63  		Location:     "sfc-customer-stage/rwyi-testacco/users/9220/",
    64  		LocationType: "S3",
    65  	}
    66  	initialParallel := int64(100)
    67  	dir, err := os.Getwd()
    68  	if err != nil {
    69  		t.Error(err)
    70  	}
    71  
    72  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
    73  	if err != nil {
    74  		t.Error(err)
    75  	}
    76  	uploadMeta := fileMetadata{
    77  		name:              "data1.txt.gz",
    78  		stageLocationType: "S3",
    79  		noSleepingTime:    false,
    80  		parallel:          initialParallel,
    81  		client:            s3Cli,
    82  		sha256Digest:      "123456789abcdef",
    83  		stageInfo:         &info,
    84  		dstFileName:       "data1.txt.gz",
    85  		srcFileName:       path.Join(dir, "/test_data/put_get_1.txt"),
    86  		overwrite:         true,
    87  		options: &SnowflakeFileTransferOptions{
    88  			MultiPartThreshold: dataSizeThreshold,
    89  		},
    90  		mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
    91  			return nil, &smithy.GenericAPIError{
    92  				Code:    errNoWsaeconnaborted,
    93  				Message: "mock err, connection aborted",
    94  			}
    95  		}),
    96  	}
    97  
    98  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
    99  	fi, err := os.Stat(uploadMeta.srcFileName)
   100  	if err != nil {
   101  		t.Error(err)
   102  	}
   103  	uploadMeta.uploadSize = fi.Size()
   104  
   105  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   106  	if err == nil {
   107  		t.Error("should have raised an error")
   108  	}
   109  	if uploadMeta.lastMaxConcurrency == 0 {
   110  		t.Fatalf("expected concurrency. got: 0")
   111  	}
   112  	if uploadMeta.lastMaxConcurrency != int(initialParallel/defaultMaxRetry) {
   113  		t.Fatalf("expected last max concurrency to be: %v, got: %v",
   114  			int(initialParallel/defaultMaxRetry), uploadMeta.lastMaxConcurrency)
   115  	}
   116  
   117  	initialParallel = 4
   118  	uploadMeta.parallel = initialParallel
   119  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   120  	if err == nil {
   121  		t.Error("should have raised an error")
   122  	}
   123  	if uploadMeta.lastMaxConcurrency == 0 {
   124  		t.Fatalf("expected no last max concurrency. got: %v",
   125  			uploadMeta.lastMaxConcurrency)
   126  	}
   127  	if uploadMeta.lastMaxConcurrency != 1 {
   128  		t.Fatalf("expected last max concurrency to be: 1, got: %v",
   129  			uploadMeta.lastMaxConcurrency)
   130  	}
   131  }
   132  
   133  func TestUploadOneFileToS3ConnReset(t *testing.T) {
   134  	info := execResponseStageInfo{
   135  		Location:     "sfc-teststage/rwyitestacco/users/1234/",
   136  		LocationType: "S3",
   137  	}
   138  	initialParallel := int64(100)
   139  	dir, err := os.Getwd()
   140  	if err != nil {
   141  		t.Error(err)
   142  	}
   143  
   144  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   145  	if err != nil {
   146  		t.Error(err)
   147  	}
   148  	uploadMeta := fileMetadata{
   149  		name:              "data1.txt.gz",
   150  		stageLocationType: "S3",
   151  		noSleepingTime:    true,
   152  		parallel:          initialParallel,
   153  		client:            s3Cli,
   154  		sha256Digest:      "123456789abcdef",
   155  		stageInfo:         &info,
   156  		dstFileName:       "data1.txt.gz",
   157  		srcFileName:       path.Join(dir, "/test_data/put_get_1.txt"),
   158  		overwrite:         true,
   159  		options: &SnowflakeFileTransferOptions{
   160  			MultiPartThreshold: dataSizeThreshold,
   161  		},
   162  		mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
   163  			return nil, &smithy.GenericAPIError{
   164  				Code:    strconv.Itoa(-1),
   165  				Message: "mock err, connection aborted",
   166  			}
   167  		}),
   168  	}
   169  
   170  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
   171  	fi, err := os.Stat(uploadMeta.srcFileName)
   172  	if err != nil {
   173  		t.Error(err)
   174  	}
   175  	uploadMeta.uploadSize = fi.Size()
   176  
   177  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   178  	if err == nil {
   179  		t.Error("should have raised an error")
   180  	}
   181  	if uploadMeta.lastMaxConcurrency != 0 {
   182  		t.Fatalf("expected no concurrency. got: %v",
   183  			uploadMeta.lastMaxConcurrency)
   184  	}
   185  }
   186  
   187  func TestUploadFileWithS3UploadFailedError(t *testing.T) {
   188  	info := execResponseStageInfo{
   189  		Location:     "sfc-teststage/rwyitestacco/users/1234/",
   190  		LocationType: "S3",
   191  	}
   192  	initialParallel := int64(100)
   193  	dir, err := os.Getwd()
   194  	if err != nil {
   195  		t.Error(err)
   196  	}
   197  
   198  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   199  	if err != nil {
   200  		t.Error(err)
   201  	}
   202  	uploadMeta := fileMetadata{
   203  		name:              "data1.txt.gz",
   204  		stageLocationType: "S3",
   205  		noSleepingTime:    true,
   206  		parallel:          initialParallel,
   207  		client:            s3Cli,
   208  		sha256Digest:      "123456789abcdef",
   209  		stageInfo:         &info,
   210  		dstFileName:       "data1.txt.gz",
   211  		srcFileName:       path.Join(dir, "/test_data/put_get_1.txt"),
   212  		overwrite:         true,
   213  		options: &SnowflakeFileTransferOptions{
   214  			MultiPartThreshold: dataSizeThreshold,
   215  		},
   216  		mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
   217  			return nil, &smithy.GenericAPIError{
   218  				Code: expiredToken,
   219  				Message: "An error occurred (ExpiredToken) when calling the " +
   220  					"operation: The provided token has expired.",
   221  			}
   222  		}),
   223  	}
   224  
   225  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
   226  	fi, err := os.Stat(uploadMeta.srcFileName)
   227  	if err != nil {
   228  		t.Error(err)
   229  	}
   230  	uploadMeta.uploadSize = fi.Size()
   231  
   232  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   233  	if err != nil {
   234  		t.Error(err)
   235  	}
   236  	if uploadMeta.resStatus != renewToken {
   237  		t.Fatalf("expected %v result status, got: %v",
   238  			renewToken, uploadMeta.resStatus)
   239  	}
   240  }
   241  
   242  type mockHeaderAPI func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error)
   243  
   244  func (m mockHeaderAPI) HeadObject(
   245  	ctx context.Context,
   246  	params *s3.HeadObjectInput,
   247  	optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   248  	return m(ctx, params, optFns...)
   249  }
   250  
   251  func TestGetHeadExpiryError(t *testing.T) {
   252  	meta := fileMetadata{
   253  		client:    s3.New(s3.Options{}),
   254  		stageInfo: &execResponseStageInfo{Location: ""},
   255  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   256  			return nil, &smithy.GenericAPIError{
   257  				Code: expiredToken,
   258  			}
   259  		}),
   260  	}
   261  	if header, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
   262  		t.Fatalf("expected null header, got: %v", header)
   263  	}
   264  	if meta.resStatus != renewToken {
   265  		t.Fatalf("expected %v result status, got: %v",
   266  			renewToken, meta.resStatus)
   267  	}
   268  }
   269  
   270  func TestGetHeaderUnexpectedError(t *testing.T) {
   271  	meta := fileMetadata{
   272  		client:    s3.New(s3.Options{}),
   273  		stageInfo: &execResponseStageInfo{Location: ""},
   274  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   275  			return nil, &smithy.GenericAPIError{
   276  				Code: "-1",
   277  			}
   278  		}),
   279  	}
   280  	if header, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
   281  		t.Fatalf("expected null header, got: %v", header)
   282  	}
   283  	if meta.resStatus != errStatus {
   284  		t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus)
   285  	}
   286  }
   287  
   288  func TestGetHeaderNonApiError(t *testing.T) {
   289  	meta := fileMetadata{
   290  		client:    s3.New(s3.Options{}),
   291  		stageInfo: &execResponseStageInfo{Location: ""},
   292  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   293  			return nil, errors.New("something went wrong here")
   294  		}),
   295  	}
   296  
   297  	header, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt")
   298  	assertNilE(t, header, fmt.Sprintf("expected header to be nil, actual: %v", header))
   299  	assertNotNilE(t, err, "expected err to not be nil")
   300  	assertEqualE(t, meta.resStatus, errStatus, fmt.Sprintf("expected %v result status for non-APIerror, got: %v", errStatus, meta.resStatus))
   301  }
   302  
   303  func TestGetHeaderNotFoundError(t *testing.T) {
   304  	meta := fileMetadata{
   305  		client:    s3.New(s3.Options{}),
   306  		stageInfo: &execResponseStageInfo{Location: ""},
   307  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   308  			return nil, &smithy.GenericAPIError{
   309  				Code: notFound,
   310  			}
   311  		}),
   312  	}
   313  
   314  	_, err := new(snowflakeS3Client).getFileHeader(&meta, "file.txt")
   315  	if err != nil && err.Error() != "could not find file" {
   316  		t.Error(err)
   317  	}
   318  
   319  	if meta.resStatus != notFoundFile {
   320  		t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus)
   321  	}
   322  }
   323  
   324  type mockDownloadObjectAPI func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error)
   325  
   326  func (m mockDownloadObjectAPI) Download(
   327  	ctx context.Context,
   328  	w io.WriterAt,
   329  	params *s3.GetObjectInput,
   330  	optFns ...func(*manager.Downloader)) (int64, error) {
   331  	return m(ctx, w, params, optFns...)
   332  }
   333  
   334  func TestDownloadFileWithS3TokenExpired(t *testing.T) {
   335  	info := execResponseStageInfo{
   336  		Location:     "sfc-teststage/rwyitestacco/users/1234/",
   337  		LocationType: "S3",
   338  	}
   339  	dir, err := os.Getwd()
   340  	if err != nil {
   341  		t.Error(err)
   342  	}
   343  
   344  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   345  	if err != nil {
   346  		t.Error(err)
   347  	}
   348  
   349  	downloadMeta := fileMetadata{
   350  		name:              "data1.txt.gz",
   351  		stageLocationType: "S3",
   352  		noSleepingTime:    true,
   353  		client:            s3Cli,
   354  		stageInfo:         &info,
   355  		dstFileName:       "data1.txt.gz",
   356  		overwrite:         true,
   357  		srcFileName:       "data1.txt.gz",
   358  		localLocation:     dir,
   359  		options: &SnowflakeFileTransferOptions{
   360  			MultiPartThreshold: dataSizeThreshold,
   361  		},
   362  		mockDownloader: mockDownloadObjectAPI(func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) {
   363  			return 0, &smithy.GenericAPIError{
   364  				Code: expiredToken,
   365  				Message: "An error occurred (ExpiredToken) when calling the " +
   366  					"operation: The provided token has expired.",
   367  			}
   368  		}),
   369  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   370  			return &s3.HeadObjectOutput{}, nil
   371  		}),
   372  	}
   373  	err = new(remoteStorageUtil).downloadOneFile(&downloadMeta)
   374  	if err == nil {
   375  		t.Error("should have raised an error")
   376  	}
   377  	if downloadMeta.resStatus != renewToken {
   378  		t.Fatalf("expected %v result status, got: %v",
   379  			renewToken, downloadMeta.resStatus)
   380  	}
   381  }
   382  
   383  func TestDownloadFileWithS3ConnReset(t *testing.T) {
   384  	info := execResponseStageInfo{
   385  		Location:     "sfc-teststage/rwyitestacco/users/1234/",
   386  		LocationType: "S3",
   387  	}
   388  	dir, err := os.Getwd()
   389  	if err != nil {
   390  		t.Error(err)
   391  	}
   392  
   393  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   394  	if err != nil {
   395  		t.Error(err)
   396  	}
   397  
   398  	downloadMeta := fileMetadata{
   399  		name:              "data1.txt.gz",
   400  		stageLocationType: "S3",
   401  		noSleepingTime:    true,
   402  		client:            s3Cli,
   403  		stageInfo:         &info,
   404  		dstFileName:       "data1.txt.gz",
   405  		overwrite:         true,
   406  		srcFileName:       "data1.txt.gz",
   407  		localLocation:     dir,
   408  		options: &SnowflakeFileTransferOptions{
   409  			MultiPartThreshold: dataSizeThreshold,
   410  		},
   411  		mockDownloader: mockDownloadObjectAPI(func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) {
   412  			return 0, &smithy.GenericAPIError{
   413  				Code:    strconv.Itoa(-1),
   414  				Message: "mock err, connection aborted",
   415  			}
   416  		}),
   417  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   418  			return &s3.HeadObjectOutput{}, nil
   419  		}),
   420  	}
   421  	err = new(remoteStorageUtil).downloadOneFile(&downloadMeta)
   422  	if err == nil {
   423  		t.Error("should have raised an error")
   424  	}
   425  	if downloadMeta.lastMaxConcurrency != 0 {
   426  		t.Fatalf("expected no concurrency. got: %v",
   427  			downloadMeta.lastMaxConcurrency)
   428  	}
   429  }
   430  
   431  func TestDownloadOneFileToS3WSAEConnAborted(t *testing.T) {
   432  	info := execResponseStageInfo{
   433  		Location:     "sfc-teststage/rwyitestacco/users/1234/",
   434  		LocationType: "S3",
   435  	}
   436  	dir, err := os.Getwd()
   437  	if err != nil {
   438  		t.Error(err)
   439  	}
   440  
   441  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   442  	if err != nil {
   443  		t.Error(err)
   444  	}
   445  
   446  	downloadMeta := fileMetadata{
   447  		name:              "data1.txt.gz",
   448  		stageLocationType: "S3",
   449  		noSleepingTime:    true,
   450  		client:            s3Cli,
   451  		stageInfo:         &info,
   452  		dstFileName:       "data1.txt.gz",
   453  		overwrite:         true,
   454  		srcFileName:       "data1.txt.gz",
   455  		localLocation:     dir,
   456  		options: &SnowflakeFileTransferOptions{
   457  			MultiPartThreshold: dataSizeThreshold,
   458  		},
   459  		mockDownloader: mockDownloadObjectAPI(func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) {
   460  			return 0, &smithy.GenericAPIError{
   461  				Code:    errNoWsaeconnaborted,
   462  				Message: "mock err, connection aborted",
   463  			}
   464  		}),
   465  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   466  			return &s3.HeadObjectOutput{}, nil
   467  		}),
   468  	}
   469  	err = new(remoteStorageUtil).downloadOneFile(&downloadMeta)
   470  	if err == nil {
   471  		t.Error("should have raised an error")
   472  	}
   473  
   474  	if downloadMeta.resStatus != needRetryWithLowerConcurrency {
   475  		t.Fatalf("expected %v result status, got: %v",
   476  			needRetryWithLowerConcurrency, downloadMeta.resStatus)
   477  	}
   478  }
   479  
   480  func TestDownloadOneFileToS3Failed(t *testing.T) {
   481  	info := execResponseStageInfo{
   482  		Location:     "sfc-teststage/rwyitestacco/users/1234/",
   483  		LocationType: "S3",
   484  	}
   485  	dir, err := os.Getwd()
   486  	if err != nil {
   487  		t.Error(err)
   488  	}
   489  
   490  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   491  	if err != nil {
   492  		t.Error(err)
   493  	}
   494  
   495  	downloadMeta := fileMetadata{
   496  		name:              "data1.txt.gz",
   497  		stageLocationType: "S3",
   498  		noSleepingTime:    true,
   499  		client:            s3Cli,
   500  		stageInfo:         &info,
   501  		dstFileName:       "data1.txt.gz",
   502  		overwrite:         true,
   503  		srcFileName:       "data1.txt.gz",
   504  		localLocation:     dir,
   505  		options: &SnowflakeFileTransferOptions{
   506  			MultiPartThreshold: dataSizeThreshold,
   507  		},
   508  		mockDownloader: mockDownloadObjectAPI(func(ctx context.Context, w io.WriterAt, params *s3.GetObjectInput, optFns ...func(*manager.Downloader)) (int64, error) {
   509  			return 0, errors.New("Failed to upload file")
   510  		}),
   511  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   512  			return &s3.HeadObjectOutput{}, nil
   513  		}),
   514  	}
   515  	err = new(remoteStorageUtil).downloadOneFile(&downloadMeta)
   516  	if err == nil {
   517  		t.Error("should have raised an error")
   518  	}
   519  
   520  	if downloadMeta.resStatus != needRetry {
   521  		t.Fatalf("expected %v result status, got: %v",
   522  			needRetry, downloadMeta.resStatus)
   523  	}
   524  }
   525  
   526  func TestUploadFileToS3ClientCastFail(t *testing.T) {
   527  	info := execResponseStageInfo{
   528  		Location:     "sfc-customer-stage/rwyi-testacco/users/9220/",
   529  		LocationType: "S3",
   530  	}
   531  	dir, err := os.Getwd()
   532  	if err != nil {
   533  		t.Error(err)
   534  	}
   535  
   536  	azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
   537  	if err != nil {
   538  		t.Error(err)
   539  	}
   540  	uploadMeta := fileMetadata{
   541  		name:              "data1.txt.gz",
   542  		stageLocationType: "S3",
   543  		noSleepingTime:    false,
   544  		client:            azureCli,
   545  		sha256Digest:      "123456789abcdef",
   546  		stageInfo:         &info,
   547  		dstFileName:       "data1.txt.gz",
   548  		srcFileName:       path.Join(dir, "/test_data/put_get_1.txt"),
   549  		overwrite:         true,
   550  		options: &SnowflakeFileTransferOptions{
   551  			MultiPartThreshold: dataSizeThreshold,
   552  		},
   553  	}
   554  
   555  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
   556  	fi, err := os.Stat(uploadMeta.srcFileName)
   557  	if err != nil {
   558  		t.Error(err)
   559  	}
   560  	uploadMeta.uploadSize = fi.Size()
   561  
   562  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   563  	if err == nil {
   564  		t.Fatal("should have failed")
   565  	}
   566  }
   567  
   568  func TestGetHeaderClientCastFail(t *testing.T) {
   569  	info := execResponseStageInfo{
   570  		Location:     "sfc-customer-stage/rwyi-testacco/users/9220/",
   571  		LocationType: "S3",
   572  	}
   573  	azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
   574  	if err != nil {
   575  		t.Error(err)
   576  	}
   577  
   578  	meta := fileMetadata{
   579  		client:    azureCli,
   580  		stageInfo: &execResponseStageInfo{Location: ""},
   581  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   582  			return nil, &smithy.GenericAPIError{
   583  				Code: notFound,
   584  			}
   585  		}),
   586  	}
   587  
   588  	_, err = new(snowflakeS3Client).getFileHeader(&meta, "file.txt")
   589  	if err == nil {
   590  		t.Fatal("should have failed")
   591  	}
   592  }
   593  
   594  func TestS3UploadRetryWithHeaderNotFound(t *testing.T) {
   595  	info := execResponseStageInfo{
   596  		Location:     "sfc-customer-stage/rwyi-testacco/users/9220/",
   597  		LocationType: "S3",
   598  	}
   599  	initialParallel := int64(100)
   600  	dir, err := os.Getwd()
   601  	if err != nil {
   602  		t.Error(err)
   603  	}
   604  
   605  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   606  	if err != nil {
   607  		t.Error(err)
   608  	}
   609  	uploadMeta := fileMetadata{
   610  		name:              "data1.txt.gz",
   611  		stageLocationType: "S3",
   612  		noSleepingTime:    false,
   613  		parallel:          initialParallel,
   614  		client:            s3Cli,
   615  		sha256Digest:      "123456789abcdef",
   616  		stageInfo:         &info,
   617  		dstFileName:       "data1.txt.gz",
   618  		srcFileName:       path.Join(dir, "/test_data/put_get_1.txt"),
   619  		overwrite:         true,
   620  		options: &SnowflakeFileTransferOptions{
   621  			MultiPartThreshold: dataSizeThreshold,
   622  		},
   623  		mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
   624  			return &manager.UploadOutput{
   625  				Location: "https://sfc-customer-stage/rwyi-testacco/users/9220/data1.txt.gz",
   626  			}, nil
   627  		}),
   628  		mockHeader: mockHeaderAPI(func(ctx context.Context, params *s3.HeadObjectInput, optFns ...func(*s3.Options)) (*s3.HeadObjectOutput, error) {
   629  			return nil, &smithy.GenericAPIError{
   630  				Code: notFound,
   631  			}
   632  		}),
   633  	}
   634  
   635  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
   636  	fi, err := os.Stat(uploadMeta.srcFileName)
   637  	if err != nil {
   638  		t.Error(err)
   639  	}
   640  	uploadMeta.uploadSize = fi.Size()
   641  
   642  	err = new(remoteStorageUtil).uploadOneFileWithRetry(&uploadMeta)
   643  	if err != nil {
   644  		t.Error(err)
   645  	}
   646  
   647  	if uploadMeta.resStatus != errStatus {
   648  		t.Fatalf("expected %v result status, got: %v", errStatus, uploadMeta.resStatus)
   649  	}
   650  }
   651  
   652  func TestS3UploadStreamFailed(t *testing.T) {
   653  	info := execResponseStageInfo{
   654  		Location:     "sfc-customer-stage/rwyi-testacco/users/9220/",
   655  		LocationType: "S3",
   656  	}
   657  	initialParallel := int64(100)
   658  	src := []byte{65, 66, 67}
   659  
   660  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   661  	if err != nil {
   662  		t.Error(err)
   663  	}
   664  
   665  	uploadMeta := fileMetadata{
   666  		name:              "data1.txt.gz",
   667  		stageLocationType: "S3",
   668  		noSleepingTime:    true,
   669  		parallel:          initialParallel,
   670  		client:            s3Cli,
   671  		sha256Digest:      "123456789abcdef",
   672  		stageInfo:         &info,
   673  		dstFileName:       "data1.txt.gz",
   674  		srcStream:         bytes.NewBuffer(src),
   675  		overwrite:         true,
   676  		options: &SnowflakeFileTransferOptions{
   677  			MultiPartThreshold: dataSizeThreshold,
   678  		},
   679  		mockUploader: mockUploadObjectAPI(func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*manager.Uploader)) (*manager.UploadOutput, error) {
   680  			return nil, errors.New("unexpected error uploading file")
   681  		}),
   682  	}
   683  
   684  	uploadMeta.realSrcStream = uploadMeta.srcStream
   685  
   686  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   687  	if err == nil {
   688  		t.Fatal("should have failed")
   689  	}
   690  }
   691  
   692  func TestConvertContentLength(t *testing.T) {
   693  	someInt := int64(1)
   694  	tcs := []struct {
   695  		contentLength any
   696  		desc          string
   697  		expected      int64
   698  	}{
   699  		{
   700  			contentLength: someInt,
   701  			desc:          "int",
   702  			expected:      1,
   703  		},
   704  		{
   705  			contentLength: &someInt,
   706  			desc:          "pointer",
   707  			expected:      1,
   708  		},
   709  		{
   710  			contentLength: float64(1),
   711  			desc:          "another type",
   712  			expected:      0,
   713  		},
   714  	}
   715  	for _, tc := range tcs {
   716  		t.Run(tc.desc, func(t *testing.T) {
   717  			actual := convertContentLength(tc.contentLength)
   718  			assertEqualF(t, actual, tc.expected, fmt.Sprintf("expected %v (%T) but got %v (%T)", actual, actual, tc.expected, tc.expected))
   719  		})
   720  	}
   721  }