github.com/snowflakedb/gosnowflake@v1.9.0/azure_storage_client_test.go (about) 1 // Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "encoding/json" 9 "errors" 10 "io" 11 "net/http" 12 "os" 13 "path" 14 "testing" 15 16 "github.com/Azure/azure-sdk-for-go/sdk/azcore" 17 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" 18 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" 19 ) 20 21 func TestExtractContainerNameAndPath(t *testing.T) { 22 azureUtil := new(snowflakeAzureClient) 23 testcases := []tcBucketPath{ 24 {"sfc-eng-regression/test_sub_dir/", "sfc-eng-regression", "test_sub_dir/"}, 25 {"sfc-eng-regression/dir/test_stg/test_sub_dir/", "sfc-eng-regression", "dir/test_stg/test_sub_dir/"}, 26 {"sfc-eng-regression/", "sfc-eng-regression", ""}, 27 {"sfc-eng-regression//", "sfc-eng-regression", "/"}, 28 {"sfc-eng-regression///", "sfc-eng-regression", "//"}, 29 } 30 for _, test := range testcases { 31 t.Run(test.in, func(t *testing.T) { 32 azureLoc, err := azureUtil.extractContainerNameAndPath(test.in) 33 if err != nil { 34 t.Error(err) 35 } 36 if azureLoc.containerName != test.bucket { 37 t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.bucket, azureLoc.containerName) 38 } 39 if azureLoc.path != test.path { 40 t.Errorf("failed. in: %v, expected: %v, got: %v", test.in, test.path, azureLoc.path) 41 } 42 }) 43 } 44 } 45 46 func TestUnitDetectAzureTokenExpireError(t *testing.T) { 47 azureUtil := new(snowflakeAzureClient) 48 dd := &execResponseData{} 49 invalidSig := &execResponse{ 50 Data: *dd, 51 Message: "Signature not valid in the specified time frame", 52 Code: "403", 53 Success: true, 54 } 55 ba, err := json.Marshal(invalidSig) 56 if err != nil { 57 panic(err) 58 } 59 resp := &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}} 60 if !azureUtil.detectAzureTokenExpireError(resp) { 61 t.Fatal("expected token expired") 62 } 63 64 invalidAuth := &execResponse{ 65 Data: *dd, 66 Message: "Server failed to authenticate the request", 67 Code: "403", 68 Success: true, 69 } 70 ba, err = json.Marshal(invalidAuth) 71 if err != nil { 72 panic(err) 73 } 74 resp = &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}} 75 if !azureUtil.detectAzureTokenExpireError(resp) { 76 t.Fatal("expected token expired") 77 } 78 79 resp = &http.Response{ 80 StatusCode: http.StatusForbidden, 81 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 82 } 83 if azureUtil.detectAzureTokenExpireError(resp) { 84 t.Fatal("invalid body") 85 } 86 87 invalidMessage := &execResponse{ 88 Data: *dd, 89 Message: "unauthorized", 90 Code: "403", 91 Success: true, 92 } 93 ba, err = json.Marshal(invalidMessage) 94 if err != nil { 95 panic(err) 96 } 97 resp = &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}} 98 if azureUtil.detectAzureTokenExpireError(resp) { 99 t.Fatal("incorrect message") 100 } 101 102 resp = &http.Response{ 103 StatusCode: http.StatusOK, 104 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}} 105 106 if azureUtil.detectAzureTokenExpireError(resp) { 107 t.Fatal("status code is success. expected false.") 108 } 109 } 110 111 type azureObjectAPIMock struct { 112 UploadStreamFunc func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) 113 UploadFileFunc func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) 114 DownloadFileFunc func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) 115 GetPropertiesFunc func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) 116 } 117 118 func (c *azureObjectAPIMock) UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) { 119 return c.UploadStreamFunc(ctx, body, o) 120 } 121 122 func (c *azureObjectAPIMock) UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) { 123 return c.UploadFileFunc(ctx, file, o) 124 } 125 126 func (c *azureObjectAPIMock) GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { 127 return c.GetPropertiesFunc(ctx, o) 128 } 129 130 func (c *azureObjectAPIMock) DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) { 131 return c.DownloadFileFunc(ctx, file, o) 132 } 133 134 func TestUploadFileWithAzureUploadFailedError(t *testing.T) { 135 info := execResponseStageInfo{ 136 Location: "azblob/storage/users/456/", 137 LocationType: "AZURE", 138 } 139 initialParallel := int64(100) 140 dir, err := os.Getwd() 141 if err != nil { 142 t.Error(err) 143 } 144 encMat := snowflakeFileEncryption{ 145 QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", 146 QueryID: "01abc874-0406-1bf0-0000-53b10668e056", 147 SMKID: 92019681909886, 148 } 149 150 azureCli, err := new(snowflakeAzureClient).createClient(&info, false) 151 if err != nil { 152 t.Error(err) 153 } 154 uploadMeta := fileMetadata{ 155 name: "data1.txt.gz", 156 stageLocationType: "AZURE", 157 noSleepingTime: true, 158 parallel: initialParallel, 159 client: azureCli, 160 sha256Digest: "123456789abcdef", 161 stageInfo: &info, 162 dstFileName: "data1.txt.gz", 163 srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), 164 encryptionMaterial: &encMat, 165 overwrite: true, 166 dstCompressionType: compressionTypes["GZIP"], 167 options: &SnowflakeFileTransferOptions{ 168 MultiPartThreshold: dataSizeThreshold, 169 }, 170 mockAzureClient: &azureObjectAPIMock{ 171 UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) { 172 return azblob.UploadFileResponse{}, errors.New("unexpected error uploading file") 173 }, 174 }, 175 } 176 177 uploadMeta.realSrcFileName = uploadMeta.srcFileName 178 fi, err := os.Stat(uploadMeta.srcFileName) 179 if err != nil { 180 t.Error(err) 181 } 182 uploadMeta.uploadSize = fi.Size() 183 184 err = new(remoteStorageUtil).uploadOneFile(&uploadMeta) 185 if err == nil { 186 t.Fatal("should have failed") 187 } 188 } 189 190 func TestUploadStreamWithAzureUploadFailedError(t *testing.T) { 191 info := execResponseStageInfo{ 192 Location: "azblob/storage/users/456/", 193 LocationType: "AZURE", 194 } 195 initialParallel := int64(100) 196 src := []byte{65, 66, 67} 197 encMat := snowflakeFileEncryption{ 198 QueryStageMasterKey: "abCdEFO0upIT36dAxGsa0w==", 199 QueryID: "01abc874-0406-1bf0-0000-53b10668e056", 200 SMKID: 92019681909886, 201 } 202 203 azureCli, err := new(snowflakeAzureClient).createClient(&info, false) 204 if err != nil { 205 t.Error(err) 206 } 207 uploadMeta := fileMetadata{ 208 name: "data1.txt.gz", 209 stageLocationType: "AZURE", 210 noSleepingTime: true, 211 parallel: initialParallel, 212 client: azureCli, 213 sha256Digest: "123456789abcdef", 214 stageInfo: &info, 215 dstFileName: "data1.txt.gz", 216 srcStream: bytes.NewBuffer(src), 217 encryptionMaterial: &encMat, 218 overwrite: true, 219 dstCompressionType: compressionTypes["GZIP"], 220 options: &SnowflakeFileTransferOptions{ 221 MultiPartThreshold: dataSizeThreshold, 222 }, 223 mockAzureClient: &azureObjectAPIMock{ 224 UploadStreamFunc: func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) { 225 return azblob.UploadStreamResponse{}, errors.New("unexpected error uploading file") 226 }, 227 }, 228 } 229 230 uploadMeta.realSrcStream = uploadMeta.srcStream 231 232 err = new(remoteStorageUtil).uploadOneFile(&uploadMeta) 233 if err == nil { 234 t.Fatal("should have failed") 235 } 236 } 237 238 func TestUploadFileWithAzureUploadTokenExpired(t *testing.T) { 239 info := execResponseStageInfo{ 240 Location: "azblob/storage/users/456/", 241 LocationType: "AZURE", 242 } 243 initialParallel := int64(100) 244 dir, err := os.Getwd() 245 if err != nil { 246 t.Error(err) 247 } 248 249 dd := &execResponseData{} 250 invalidSig := &execResponse{ 251 Data: *dd, 252 Message: "Signature not valid in the specified time frame", 253 Code: "403", 254 Success: true, 255 } 256 ba, err := json.Marshal(invalidSig) 257 if err != nil { 258 panic(err) 259 } 260 261 azureCli, err := new(snowflakeAzureClient).createClient(&info, false) 262 if err != nil { 263 t.Error(err) 264 } 265 uploadMeta := fileMetadata{ 266 name: "data1.txt.gz", 267 stageLocationType: "AZURE", 268 noSleepingTime: true, 269 parallel: initialParallel, 270 client: azureCli, 271 sha256Digest: "123456789abcdef", 272 stageInfo: &info, 273 dstFileName: "data1.txt.gz", 274 srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), 275 overwrite: true, 276 dstCompressionType: compressionTypes["GZIP"], 277 options: &SnowflakeFileTransferOptions{ 278 MultiPartThreshold: dataSizeThreshold, 279 }, 280 mockAzureClient: &azureObjectAPIMock{ 281 UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) { 282 return azblob.UploadFileResponse{}, &azcore.ResponseError{ 283 ErrorCode: "12345", 284 StatusCode: 403, 285 RawResponse: &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}, 286 } 287 }, 288 }, 289 } 290 291 uploadMeta.realSrcFileName = uploadMeta.srcFileName 292 fi, err := os.Stat(uploadMeta.srcFileName) 293 if err != nil { 294 t.Error(err) 295 } 296 uploadMeta.uploadSize = fi.Size() 297 298 err = new(remoteStorageUtil).uploadOneFile(&uploadMeta) 299 if err != nil { 300 t.Fatal(err) 301 } 302 303 if uploadMeta.resStatus != renewToken { 304 t.Fatalf("expected %v result status, got: %v", 305 renewToken, uploadMeta.resStatus) 306 } 307 } 308 309 func TestUploadFileWithAzureUploadNeedsRetry(t *testing.T) { 310 info := execResponseStageInfo{ 311 Location: "azblob/storage/users/456/", 312 LocationType: "AZURE", 313 } 314 initialParallel := int64(100) 315 dir, err := os.Getwd() 316 if err != nil { 317 t.Error(err) 318 } 319 320 dd := &execResponseData{} 321 invalidSig := &execResponse{ 322 Data: *dd, 323 Message: "Server Error", 324 Code: "500", 325 Success: true, 326 } 327 ba, err := json.Marshal(invalidSig) 328 if err != nil { 329 panic(err) 330 } 331 332 azureCli, err := new(snowflakeAzureClient).createClient(&info, false) 333 if err != nil { 334 t.Error(err) 335 } 336 uploadMeta := fileMetadata{ 337 name: "data1.txt.gz", 338 stageLocationType: "AZURE", 339 noSleepingTime: false, 340 parallel: initialParallel, 341 client: azureCli, 342 sha256Digest: "123456789abcdef", 343 stageInfo: &info, 344 dstFileName: "data1.txt.gz", 345 srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), 346 overwrite: true, 347 dstCompressionType: compressionTypes["GZIP"], 348 options: &SnowflakeFileTransferOptions{ 349 MultiPartThreshold: dataSizeThreshold, 350 }, 351 mockAzureClient: &azureObjectAPIMock{ 352 UploadFileFunc: func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) { 353 return azblob.UploadFileResponse{}, &azcore.ResponseError{ 354 ErrorCode: "12345", 355 StatusCode: 500, 356 RawResponse: &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}, 357 } 358 }, 359 }, 360 } 361 362 uploadMeta.realSrcFileName = uploadMeta.srcFileName 363 fi, err := os.Stat(uploadMeta.srcFileName) 364 if err != nil { 365 t.Error(err) 366 } 367 uploadMeta.uploadSize = fi.Size() 368 369 err = new(remoteStorageUtil).uploadOneFile(&uploadMeta) 370 if err == nil { 371 t.Fatal("should have raised an error") 372 } 373 374 if uploadMeta.resStatus != needRetry { 375 t.Fatalf("expected %v result status, got: %v", 376 needRetry, uploadMeta.resStatus) 377 } 378 } 379 380 func TestDownloadOneFileToAzureFailed(t *testing.T) { 381 info := execResponseStageInfo{ 382 Location: "azblob/rwyitestacco/users/1234/", 383 LocationType: "AZURE", 384 } 385 dir, err := os.Getwd() 386 if err != nil { 387 t.Error(err) 388 } 389 390 azureCli, err := new(snowflakeAzureClient).createClient(&info, false) 391 if err != nil { 392 t.Error(err) 393 } 394 395 downloadMeta := fileMetadata{ 396 name: "data1.txt.gz", 397 stageLocationType: "AZURE", 398 noSleepingTime: true, 399 client: azureCli, 400 stageInfo: &info, 401 dstFileName: "data1.txt.gz", 402 overwrite: true, 403 srcFileName: "data1.txt.gz", 404 localLocation: dir, 405 options: &SnowflakeFileTransferOptions{ 406 MultiPartThreshold: dataSizeThreshold, 407 }, 408 mockAzureClient: &azureObjectAPIMock{ 409 DownloadFileFunc: func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) { 410 return 0, errors.New("unexpected error uploading file") 411 }, 412 GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { 413 return blob.GetPropertiesResponse{}, nil 414 }, 415 }, 416 } 417 err = new(remoteStorageUtil).downloadOneFile(&downloadMeta) 418 if err == nil { 419 t.Error("should have raised an error") 420 } 421 } 422 423 func TestGetFileHeaderErrorStatus(t *testing.T) { 424 info := execResponseStageInfo{ 425 Location: "azblob/teststage/users/34/", 426 LocationType: "AZURE", 427 } 428 429 azureCli, err := new(snowflakeAzureClient).createClient(&info, false) 430 if err != nil { 431 t.Error(err) 432 } 433 434 meta := fileMetadata{ 435 client: azureCli, 436 stageInfo: &info, 437 mockAzureClient: &azureObjectAPIMock{ 438 GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { 439 return blob.GetPropertiesResponse{}, errors.New("failed to retrieve headers") 440 }, 441 }, 442 } 443 444 if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { 445 t.Fatalf("expected null header, got: %v", header) 446 } 447 if meta.resStatus != errStatus { 448 t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus) 449 } 450 451 dd := &execResponseData{} 452 invalidSig := &execResponse{ 453 Data: *dd, 454 Message: "Not Found", 455 Code: "404", 456 Success: true, 457 } 458 ba, err := json.Marshal(invalidSig) 459 if err != nil { 460 panic(err) 461 } 462 463 meta = fileMetadata{ 464 client: azureCli, 465 stageInfo: &info, 466 mockAzureClient: &azureObjectAPIMock{ 467 GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { 468 return blob.GetPropertiesResponse{}, &azcore.ResponseError{ 469 ErrorCode: "BlobNotFound", 470 StatusCode: 404, 471 RawResponse: &http.Response{StatusCode: http.StatusNotFound, Body: &fakeResponseBody{body: ba}}, 472 } 473 }, 474 }, 475 } 476 477 if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { 478 t.Fatalf("expected null header, got: %v", header) 479 } 480 if meta.resStatus != notFoundFile { 481 t.Fatalf("expected %v result status, got: %v", errStatus, meta.resStatus) 482 } 483 484 invalidSig = &execResponse{ 485 Data: *dd, 486 Message: "Unauthorized", 487 Code: "403", 488 Success: true, 489 } 490 ba, err = json.Marshal(invalidSig) 491 if err != nil { 492 panic(err) 493 } 494 meta.mockAzureClient = &azureObjectAPIMock{ 495 GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { 496 return blob.GetPropertiesResponse{}, &azcore.ResponseError{ 497 StatusCode: 403, 498 RawResponse: &http.Response{StatusCode: http.StatusForbidden, Body: &fakeResponseBody{body: ba}}, 499 } 500 }, 501 } 502 503 if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil { 504 t.Fatalf("expected null header, got: %v", header) 505 } 506 if meta.resStatus != renewToken { 507 t.Fatalf("expected %v result status, got: %v", renewToken, meta.resStatus) 508 } 509 } 510 511 func TestUploadFileToAzureClientCastFail(t *testing.T) { 512 info := execResponseStageInfo{ 513 Location: "azblob/rwyi-testacco/users/9220/", 514 LocationType: "AZURE", 515 } 516 dir, err := os.Getwd() 517 if err != nil { 518 t.Error(err) 519 } 520 521 s3Cli, err := new(snowflakeS3Client).createClient(&info, false) 522 if err != nil { 523 t.Error(err) 524 } 525 uploadMeta := fileMetadata{ 526 name: "data1.txt.gz", 527 stageLocationType: "AZURE", 528 noSleepingTime: false, 529 client: s3Cli, 530 sha256Digest: "123456789abcdef", 531 stageInfo: &info, 532 dstFileName: "data1.txt.gz", 533 srcFileName: path.Join(dir, "/test_data/put_get_1.txt"), 534 overwrite: true, 535 options: &SnowflakeFileTransferOptions{ 536 MultiPartThreshold: dataSizeThreshold, 537 }, 538 } 539 540 uploadMeta.realSrcFileName = uploadMeta.srcFileName 541 fi, err := os.Stat(uploadMeta.srcFileName) 542 if err != nil { 543 t.Error(err) 544 } 545 uploadMeta.uploadSize = fi.Size() 546 547 err = new(remoteStorageUtil).uploadOneFile(&uploadMeta) 548 if err == nil { 549 t.Fatal("should have failed") 550 } 551 } 552 553 func TestAzureGetHeaderClientCastFail(t *testing.T) { 554 info := execResponseStageInfo{ 555 Location: "azblob/rwyi-testacco/users/9220/", 556 LocationType: "AZURE", 557 } 558 s3Cli, err := new(snowflakeS3Client).createClient(&info, false) 559 if err != nil { 560 t.Error(err) 561 } 562 563 meta := fileMetadata{ 564 client: s3Cli, 565 stageInfo: &execResponseStageInfo{Location: ""}, 566 mockAzureClient: &azureObjectAPIMock{ 567 GetPropertiesFunc: func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) { 568 return blob.GetPropertiesResponse{}, nil 569 }, 570 }, 571 } 572 573 _, err = new(snowflakeAzureClient).getFileHeader(&meta, "file.txt") 574 if err == nil { 575 t.Fatal("should have failed") 576 } 577 }