github.com/snowflakedb/gosnowflake@v1.9.0/azure_storage_client_test.go (about)

     1  // Copyright (c) 2023 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"encoding/json"
     9  	"errors"
    10  	"io"
    11  	"net/http"
    12  	"os"
    13  	"path"
    14  	"testing"
    15  
    16  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
    17  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
    18  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
    19  )
    20  
    21  func TestExtractContainerNameAndPath(t *testing.T) {
    22  	azureUtil := new(snowflakeAzureClient)
    23  	testcases := []tcBucketPath{
    24  		{"sfc-eng-regression/test_sub_dir/", "sfc-eng-regression", "test_sub_dir/"},
    25  		{"sfc-eng-regression/dir/test_stg/test_sub_dir/", "sfc-eng-regression", "dir/test_stg/test_sub_dir/"},
    26  		{"sfc-eng-regression/", "sfc-eng-regression", ""},
    27  		{"sfc-eng-regression//", "sfc-eng-regression", "/"},
    28  		{"sfc-eng-regression///", "sfc-eng-regression", "//"},
    29  	}
    30  	for _, test := range testcases {
    31  		t.Run(test.in, func(t *testing.T) {
    32  			azureLoc, err := azureUtil.extractContainerNameAndPath(test.in)
    33  			if err != nil {
    34  				t.Error(err)
    35  			}
    36  			if azureLoc.containerName != test.bucket {
    37  				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.bucket, azureLoc.containerName)
    38  			}
    39  			if azureLoc.path != test.path {
    40  				t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.path, azureLoc.path)
    41  			}
    42  		})
    43  	}
    44  }
    45  
    46  func TestUnitDetectAzureTokenExpireError(t *testing.T) {
    47  	azureUtil := new(snowflakeAzureClient)
    48  	dd := &execResponseData{}
    49  	invalidSig := &execResponse{
    50  		Data:    *dd,
    51  		Message: "Signature not valid in the specified time frame",
    52  		Code:    "403",
    53  		Success: true,
    54  	}
    55  	ba, err := json.Marshal(invalidSig)
    56  	if err != nil {
    57  		panic(err)
    58  	}
    59  	resp := &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}
    60  	if !azureUtil.detectAzureTokenExpireError(resp) {
    61  		t.Fatal("expected token expired")
    62  	}
    63  
    64  	invalidAuth := &execResponse{
    65  		Data:    *dd,
    66  		Message: "Server failed to authenticate the request",
    67  		Code:    "403",
    68  		Success: true,
    69  	}
    70  	ba, err = json.Marshal(invalidAuth)
    71  	if err != nil {
    72  		panic(err)
    73  	}
    74  	resp = &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}
    75  	if !azureUtil.detectAzureTokenExpireError(resp) {
    76  		t.Fatal("expected token expired")
    77  	}
    78  
    79  	resp = &http.Response{
    80  		StatusCode: http.StatusForbidden,
    81  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    82  	}
    83  	if azureUtil.detectAzureTokenExpireError(resp) {
    84  		t.Fatal("invalid body")
    85  	}
    86  
    87  	invalidMessage := &execResponse{
    88  		Data:    *dd,
    89  		Message: "unauthorized",
    90  		Code:    "403",
    91  		Success: true,
    92  	}
    93  	ba, err = json.Marshal(invalidMessage)
    94  	if err != nil {
    95  		panic(err)
    96  	}
    97  	resp = &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}
    98  	if azureUtil.detectAzureTokenExpireError(resp) {
    99  		t.Fatal("incorrect message")
   100  	}
   101  
   102  	resp = &http.Response{
   103  		StatusCode: http.StatusOK,
   104  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}}}
   105  
   106  	if azureUtil.detectAzureTokenExpireError(resp) {
   107  		t.Fatal("status code is success. expected false.")
   108  	}
   109  }
   110  
   111  type azureObjectAPIMock struct {
   112  	UploadStreamFunc  func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
   113  	UploadFileFunc    func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
   114  	DownloadFileFunc  func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
   115  	GetPropertiesFunc func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
   116  }
   117  
   118  func (c *azureObjectAPIMock) UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) {
   119  	return c.UploadStreamFunc(ctx, body, o)
   120  }
   121  
   122  func (c *azureObjectAPIMock) UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) {
   123  	return c.UploadFileFunc(ctx, file, o)
   124  }
   125  
   126  func (c *azureObjectAPIMock) GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) {
   127  	return c.GetPropertiesFunc(ctx, o)
   128  }
   129  
   130  func (c *azureObjectAPIMock) DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) {
   131  	return c.DownloadFileFunc(ctx, file, o)
   132  }
   133  
   134  func TestUploadFileWithAzureUploadFailedError(t *testing.T) {
   135  	info := execResponseStageInfo{
   136  		Location:     "azblob/storage/users/456/",
   137  		LocationType: "AZURE",
   138  	}
   139  	initialParallel := int64(100)
   140  	dir, err := os.Getwd()
   141  	if err != nil {
   142  		t.Error(err)
   143  	}
   144  	encMat := snowflakeFileEncryption{
   145  		QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==",
   146  		QueryID:             "01abc874-0406-1bf0-0000-53b10668e056",
   147  		SMKID:               92019681909886,
   148  	}
   149  
   150  	azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
   151  	if err != nil {
   152  		t.Error(err)
   153  	}
   154  	uploadMeta := fileMetadata{
   155  		name:               "data1.txt.gz",
   156  		stageLocationType:  "AZURE",
   157  		noSleepingTime:     true,
   158  		parallel:           initialParallel,
   159  		client:             azureCli,
   160  		sha256Digest:       "123456789abcdef",
   161  		stageInfo:          &info,
   162  		dstFileName:        "data1.txt.gz",
   163  		srcFileName:        path.Join(dir, "/test_data/put_get_1.txt"),
   164  		encryptionMaterial: &encMat,
   165  		overwrite:          true,
   166  		dstCompressionType: compressionTypes["GZIP"],
   167  		options: &SnowflakeFileTransferOptions{
   168  			MultiPartThreshold: dataSizeThreshold,
   169  		},
   170  		mockAzureClient: &azureObjectAPIMock{
   171  			UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) {
   172  				return azblob.UploadFileResponse{}, errors.New("unexpected error uploading file")
   173  			},
   174  		},
   175  	}
   176  
   177  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
   178  	fi, err := os.Stat(uploadMeta.srcFileName)
   179  	if err != nil {
   180  		t.Error(err)
   181  	}
   182  	uploadMeta.uploadSize = fi.Size()
   183  
   184  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   185  	if err == nil {
   186  		t.Fatal("should have failed")
   187  	}
   188  }
   189  
   190  func TestUploadStreamWithAzureUploadFailedError(t *testing.T) {
   191  	info := execResponseStageInfo{
   192  		Location:     "azblob/storage/users/456/",
   193  		LocationType: "AZURE",
   194  	}
   195  	initialParallel := int64(100)
   196  	src := []byte{65, 66, 67}
   197  	encMat := snowflakeFileEncryption{
   198  		QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==",
   199  		QueryID:             "01abc874-0406-1bf0-0000-53b10668e056",
   200  		SMKID:               92019681909886,
   201  	}
   202  
   203  	azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
   204  	if err != nil {
   205  		t.Error(err)
   206  	}
   207  	uploadMeta := fileMetadata{
   208  		name:               "data1.txt.gz",
   209  		stageLocationType:  "AZURE",
   210  		noSleepingTime:     true,
   211  		parallel:           initialParallel,
   212  		client:             azureCli,
   213  		sha256Digest:       "123456789abcdef",
   214  		stageInfo:          &info,
   215  		dstFileName:        "data1.txt.gz",
   216  		srcStream:          bytes.NewBuffer(src),
   217  		encryptionMaterial: &encMat,
   218  		overwrite:          true,
   219  		dstCompressionType: compressionTypes["GZIP"],
   220  		options: &SnowflakeFileTransferOptions{
   221  			MultiPartThreshold: dataSizeThreshold,
   222  		},
   223  		mockAzureClient: &azureObjectAPIMock{
   224  			UploadStreamFunc: func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) {
   225  				return azblob.UploadStreamResponse{}, errors.New("unexpected error uploading file")
   226  			},
   227  		},
   228  	}
   229  
   230  	uploadMeta.realSrcStream = uploadMeta.srcStream
   231  
   232  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   233  	if err == nil {
   234  		t.Fatal("should have failed")
   235  	}
   236  }
   237  
   238  func TestUploadFileWithAzureUploadTokenExpired(t *testing.T) {
   239  	info := execResponseStageInfo{
   240  		Location:     "azblob/storage/users/456/",
   241  		LocationType: "AZURE",
   242  	}
   243  	initialParallel := int64(100)
   244  	dir, err := os.Getwd()
   245  	if err != nil {
   246  		t.Error(err)
   247  	}
   248  
   249  	dd := &execResponseData{}
   250  	invalidSig := &execResponse{
   251  		Data:    *dd,
   252  		Message: "Signature not valid in the specified time frame",
   253  		Code:    "403",
   254  		Success: true,
   255  	}
   256  	ba, err := json.Marshal(invalidSig)
   257  	if err != nil {
   258  		panic(err)
   259  	}
   260  
   261  	azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
   262  	if err != nil {
   263  		t.Error(err)
   264  	}
   265  	uploadMeta := fileMetadata{
   266  		name:               "data1.txt.gz",
   267  		stageLocationType:  "AZURE",
   268  		noSleepingTime:     true,
   269  		parallel:           initialParallel,
   270  		client:             azureCli,
   271  		sha256Digest:       "123456789abcdef",
   272  		stageInfo:          &info,
   273  		dstFileName:        "data1.txt.gz",
   274  		srcFileName:        path.Join(dir, "/test_data/put_get_1.txt"),
   275  		overwrite:          true,
   276  		dstCompressionType: compressionTypes["GZIP"],
   277  		options: &SnowflakeFileTransferOptions{
   278  			MultiPartThreshold: dataSizeThreshold,
   279  		},
   280  		mockAzureClient: &azureObjectAPIMock{
   281  			UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) {
   282  				return azblob.UploadFileResponse{}, &azcore.ResponseError{
   283  					ErrorCode:   "12345",
   284  					StatusCode:  403,
   285  					RawResponse: &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}},
   286  				}
   287  			},
   288  		},
   289  	}
   290  
   291  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
   292  	fi, err := os.Stat(uploadMeta.srcFileName)
   293  	if err != nil {
   294  		t.Error(err)
   295  	}
   296  	uploadMeta.uploadSize = fi.Size()
   297  
   298  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   299  	if err != nil {
   300  		t.Fatal(err)
   301  	}
   302  
   303  	if uploadMeta.resStatus != renewToken {
   304  		t.Fatalf("expected %v result status, got: %v",
   305  			renewToken, uploadMeta.resStatus)
   306  	}
   307  }
   308  
   309  func TestUploadFileWithAzureUploadNeedsRetry(t *testing.T) {
   310  	info := execResponseStageInfo{
   311  		Location:     "azblob/storage/users/456/",
   312  		LocationType: "AZURE",
   313  	}
   314  	initialParallel := int64(100)
   315  	dir, err := os.Getwd()
   316  	if err != nil {
   317  		t.Error(err)
   318  	}
   319  
   320  	dd := &execResponseData{}
   321  	invalidSig := &execResponse{
   322  		Data:    *dd,
   323  		Message: "Server Error",
   324  		Code:    "500",
   325  		Success: true,
   326  	}
   327  	ba, err := json.Marshal(invalidSig)
   328  	if err != nil {
   329  		panic(err)
   330  	}
   331  
   332  	azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
   333  	if err != nil {
   334  		t.Error(err)
   335  	}
   336  	uploadMeta := fileMetadata{
   337  		name:               "data1.txt.gz",
   338  		stageLocationType:  "AZURE",
   339  		noSleepingTime:     false,
   340  		parallel:           initialParallel,
   341  		client:             azureCli,
   342  		sha256Digest:       "123456789abcdef",
   343  		stageInfo:          &info,
   344  		dstFileName:        "data1.txt.gz",
   345  		srcFileName:        path.Join(dir, "/test_data/put_get_1.txt"),
   346  		overwrite:          true,
   347  		dstCompressionType: compressionTypes["GZIP"],
   348  		options: &SnowflakeFileTransferOptions{
   349  			MultiPartThreshold: dataSizeThreshold,
   350  		},
   351  		mockAzureClient: &azureObjectAPIMock{
   352  			UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) {
   353  				return azblob.UploadFileResponse{}, &azcore.ResponseError{
   354  					ErrorCode:   "12345",
   355  					StatusCode:  500,
   356  					RawResponse: &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}},
   357  				}
   358  			},
   359  		},
   360  	}
   361  
   362  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
   363  	fi, err := os.Stat(uploadMeta.srcFileName)
   364  	if err != nil {
   365  		t.Error(err)
   366  	}
   367  	uploadMeta.uploadSize = fi.Size()
   368  
   369  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   370  	if err == nil {
   371  		t.Fatal("should have raised an error")
   372  	}
   373  
   374  	if uploadMeta.resStatus != needRetry {
   375  		t.Fatalf("expected %v result status, got: %v",
   376  			needRetry, uploadMeta.resStatus)
   377  	}
   378  }
   379  
   380  func TestDownloadOneFileToAzureFailed(t *testing.T) {
   381  	info := execResponseStageInfo{
   382  		Location:     "azblob/rwyitestacco/users/1234/",
   383  		LocationType: "AZURE",
   384  	}
   385  	dir, err := os.Getwd()
   386  	if err != nil {
   387  		t.Error(err)
   388  	}
   389  
   390  	azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
   391  	if err != nil {
   392  		t.Error(err)
   393  	}
   394  
   395  	downloadMeta := fileMetadata{
   396  		name:              "data1.txt.gz",
   397  		stageLocationType: "AZURE",
   398  		noSleepingTime:    true,
   399  		client:            azureCli,
   400  		stageInfo:         &info,
   401  		dstFileName:       "data1.txt.gz",
   402  		overwrite:         true,
   403  		srcFileName:       "data1.txt.gz",
   404  		localLocation:     dir,
   405  		options: &SnowflakeFileTransferOptions{
   406  			MultiPartThreshold: dataSizeThreshold,
   407  		},
   408  		mockAzureClient: &azureObjectAPIMock{
   409  			DownloadFileFunc: func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) {
   410  				return 0, errors.New("unexpected error uploading file")
   411  			},
   412  			GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) {
   413  				return blob.GetPropertiesResponse{}, nil
   414  			},
   415  		},
   416  	}
   417  	err = new(remoteStorageUtil).downloadOneFile(&downloadMeta)
   418  	if err == nil {
   419  		t.Error("should have raised an error")
   420  	}
   421  }
   422  
   423  func TestGetFileHeaderErrorStatus(t *testing.T) {
   424  	info := execResponseStageInfo{
   425  		Location:     "azblob/teststage/users/34/",
   426  		LocationType: "AZURE",
   427  	}
   428  
   429  	azureCli, err := new(snowflakeAzureClient).createClient(&info, false)
   430  	if err != nil {
   431  		t.Error(err)
   432  	}
   433  
   434  	meta := fileMetadata{
   435  		client:    azureCli,
   436  		stageInfo: &info,
   437  		mockAzureClient: &azureObjectAPIMock{
   438  			GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) {
   439  				return blob.GetPropertiesResponse{}, errors.New("failed to retrieve headers")
   440  			},
   441  		},
   442  	}
   443  
   444  	if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
   445  		t.Fatalf("expected null header, got: %v", header)
   446  	}
   447  	if meta.resStatus != errStatus {
   448  		t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus)
   449  	}
   450  
   451  	dd := &execResponseData{}
   452  	invalidSig := &execResponse{
   453  		Data:    *dd,
   454  		Message: "Not Found",
   455  		Code:    "404",
   456  		Success: true,
   457  	}
   458  	ba, err := json.Marshal(invalidSig)
   459  	if err != nil {
   460  		panic(err)
   461  	}
   462  
   463  	meta = fileMetadata{
   464  		client:    azureCli,
   465  		stageInfo: &info,
   466  		mockAzureClient: &azureObjectAPIMock{
   467  			GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) {
   468  				return blob.GetPropertiesResponse{}, &azcore.ResponseError{
   469  					ErrorCode:   "BlobNotFound",
   470  					StatusCode:  404,
   471  					RawResponse: &http.Response{StatusCode: http.StatusNotFound, Body: &fakeResponseBody{body: ba}},
   472  				}
   473  			},
   474  		},
   475  	}
   476  
   477  	if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
   478  		t.Fatalf("expected null header, got: %v", header)
   479  	}
   480  	if meta.resStatus != notFoundFile {
   481  		t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus)
   482  	}
   483  
   484  	invalidSig = &execResponse{
   485  		Data:    *dd,
   486  		Message: "Unauthorized",
   487  		Code:    "403",
   488  		Success: true,
   489  	}
   490  	ba, err = json.Marshal(invalidSig)
   491  	if err != nil {
   492  		panic(err)
   493  	}
   494  	meta.mockAzureClient = &azureObjectAPIMock{
   495  		GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) {
   496  			return blob.GetPropertiesResponse{}, &azcore.ResponseError{
   497  				StatusCode:  403,
   498  				RawResponse: &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}},
   499  			}
   500  		},
   501  	}
   502  
   503  	if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
   504  		t.Fatalf("expected null header, got: %v", header)
   505  	}
   506  	if meta.resStatus != renewToken {
   507  		t.Fatalf("expected %v result status, got: %v", renewToken, meta.resStatus)
   508  	}
   509  }
   510  
   511  func TestUploadFileToAzureClientCastFail(t *testing.T) {
   512  	info := execResponseStageInfo{
   513  		Location:     "azblob/rwyi-testacco/users/9220/",
   514  		LocationType: "AZURE",
   515  	}
   516  	dir, err := os.Getwd()
   517  	if err != nil {
   518  		t.Error(err)
   519  	}
   520  
   521  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   522  	if err != nil {
   523  		t.Error(err)
   524  	}
   525  	uploadMeta := fileMetadata{
   526  		name:              "data1.txt.gz",
   527  		stageLocationType: "AZURE",
   528  		noSleepingTime:    false,
   529  		client:            s3Cli,
   530  		sha256Digest:      "123456789abcdef",
   531  		stageInfo:         &info,
   532  		dstFileName:       "data1.txt.gz",
   533  		srcFileName:       path.Join(dir, "/test_data/put_get_1.txt"),
   534  		overwrite:         true,
   535  		options: &SnowflakeFileTransferOptions{
   536  			MultiPartThreshold: dataSizeThreshold,
   537  		},
   538  	}
   539  
   540  	uploadMeta.realSrcFileName = uploadMeta.srcFileName
   541  	fi, err := os.Stat(uploadMeta.srcFileName)
   542  	if err != nil {
   543  		t.Error(err)
   544  	}
   545  	uploadMeta.uploadSize = fi.Size()
   546  
   547  	err = new(remoteStorageUtil).uploadOneFile(&uploadMeta)
   548  	if err == nil {
   549  		t.Fatal("should have failed")
   550  	}
   551  }
   552  
   553  func TestAzureGetHeaderClientCastFail(t *testing.T) {
   554  	info := execResponseStageInfo{
   555  		Location:     "azblob/rwyi-testacco/users/9220/",
   556  		LocationType: "AZURE",
   557  	}
   558  	s3Cli, err := new(snowflakeS3Client).createClient(&info, false)
   559  	if err != nil {
   560  		t.Error(err)
   561  	}
   562  
   563  	meta := fileMetadata{
   564  		client:    s3Cli,
   565  		stageInfo: &execResponseStageInfo{Location: ""},
   566  		mockAzureClient: &azureObjectAPIMock{
   567  			GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) {
   568  				return blob.GetPropertiesResponse{}, nil
   569  			},
   570  		},
   571  	}
   572  
   573  	_, err = new(snowflakeAzureClient).getFileHeader(&meta, "file.txt")
   574  	if err == nil {
   575  		t.Fatal("should have failed")
   576  	}
   577  }