github.com/snowflakedb/gosnowflake@v1.9.0/azure_storage_client.go (about) 1 // Copyright (c) 2021-2023 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "io" 11 "net/http" 12 "net/url" 13 "os" 14 "strings" 15 "time" 16 17 "github.com/Azure/azure-sdk-for-go/sdk/azcore" 18 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" 19 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" 20 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" 21 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" 22 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" 23 ) 24 25 type snowflakeAzureClient struct { 26 } 27 28 type azureLocation struct { 29 containerName string 30 path string 31 } 32 33 type azureAPI interface { 34 UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) 35 UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error) 36 DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error) 37 GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error) 38 } 39 40 func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) { 41 sasToken := info.Creds.AzureSasToken 42 u, err := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken)) 43 if err != nil { 44 return nil, err 45 } 46 client, err := azblob.NewClientWithNoCredential(u.String(), &azblob.ClientOptions{ 47 ClientOptions: azcore.ClientOptions{ 48 Retry: policy.RetryOptions{ 49 MaxRetries: 60, 50 RetryDelay: 2 * time.Second, 51 }, 52 Transport: &http.Client{ 53 Transport: SnowflakeTransport, 54 }, 55 }, 56 }) 57 if err != nil { 58 return nil, err 59 } 60 return client, nil 61 } 62 63 // cloudUtil implementation 64 func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename string) (*fileHeader, error) { 65 client, ok := meta.client.(*azblob.Client) 66 if !ok { 67 return nil, fmt.Errorf("failed to parse client to azblob.Client") 68 } 69 70 azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location) 71 if err != nil { 72 return nil, err 73 } 74 path := azureLoc.path + strings.TrimLeft(filename, "/") 75 containerClient, err := container.NewClientWithNoCredential(client.URL(), &container.ClientOptions{}) 76 if err != nil { 77 return nil, &SnowflakeError{ 78 Message: "failed to create container client", 79 } 80 } 81 var blobClient azureAPI 82 blobClient = containerClient.NewBlockBlobClient(path) 83 // for testing only 84 if meta.mockAzureClient != nil { 85 blobClient = meta.mockAzureClient 86 } 87 resp, err := blobClient.GetProperties(context.Background(), &blob.GetPropertiesOptions{ 88 AccessConditions: &blob.AccessConditions{}, 89 CPKInfo: &blob.CPKInfo{}, 90 }) 91 if err != nil { 92 var se *azcore.ResponseError 93 if errors.As(err, &se) { 94 if se.ErrorCode == string(bloberror.BlobNotFound) { 95 meta.resStatus = notFoundFile 96 return nil, fmt.Errorf("could not find file") 97 } else if se.StatusCode == 403 { 98 meta.resStatus = renewToken 99 return nil, fmt.Errorf("received 403, attempting to renew") 100 } 101 } 102 meta.resStatus = errStatus 103 return nil, err 104 } 105 106 meta.resStatus = uploaded 107 metadata := resp.Metadata 108 var encData encryptionData 109 110 _, ok = metadata["Encryptiondata"] 111 if ok { 112 if err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData); err != nil { 113 return nil, err 114 } 115 } 116 117 matdesc, ok := metadata["Matdesc"] 118 if !ok { 119 // matdesc is not in response, use empty string 120 matdesc = new(string) 121 } 122 encryptionMetadata := encryptMetadata{ 123 encData.WrappedContentKey.EncryptionKey, 124 encData.ContentEncryptionIV, 125 *matdesc, 126 } 127 128 digest, ok := metadata["Sfcdigest"] 129 if !ok { 130 // sfcdigest is not in response, use empty string 131 digest = new(string) 132 } 133 return &fileHeader{ 134 *digest, 135 int64(len(metadata)), 136 &encryptionMetadata, 137 }, nil 138 } 139 140 // cloudUtil implementation 141 func (util *snowflakeAzureClient) uploadFile( 142 dataFile string, 143 meta *fileMetadata, 144 encryptMeta *encryptMetadata, 145 maxConcurrency int, 146 multiPartThreshold int64) error { 147 azureMeta := map[string]*string{ 148 "sfcdigest": &meta.sha256Digest, 149 } 150 if encryptMeta != nil { 151 ed := &encryptionData{ 152 EncryptionMode: "FullBlob", 153 WrappedContentKey: contentKey{ 154 "symmKey1", 155 encryptMeta.key, 156 "AES_CBC_256", 157 }, 158 EncryptionAgent: encryptionAgent{ 159 "1.0", 160 "AES_CBC_128", 161 }, 162 ContentEncryptionIV: encryptMeta.iv, 163 KeyWrappingMetadata: keyMetadata{ 164 "Java 5.3.0", 165 }, 166 } 167 metadata, err := json.Marshal(ed) 168 if err != nil { 169 return err 170 } 171 encryptionMetadata := string(metadata) 172 azureMeta["encryptiondata"] = &encryptionMetadata 173 azureMeta["matdesc"] = &encryptMeta.matdesc 174 } 175 176 azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location) 177 if err != nil { 178 return err 179 } 180 path := azureLoc.path + strings.TrimLeft(meta.dstFileName, "/") 181 client, ok := meta.client.(*azblob.Client) 182 if !ok { 183 return &SnowflakeError{ 184 Message: "failed to cast to azure client", 185 } 186 } 187 containerClient, err := container.NewClientWithNoCredential(client.URL(), &container.ClientOptions{}) 188 if err != nil { 189 return &SnowflakeError{ 190 Message: "failed to create container client", 191 } 192 } 193 var blobClient azureAPI 194 blobClient = containerClient.NewBlockBlobClient(path) 195 // for testing only 196 if meta.mockAzureClient != nil { 197 blobClient = meta.mockAzureClient 198 } 199 if meta.srcStream != nil { 200 uploadSrc := meta.srcStream 201 if meta.realSrcStream != nil { 202 uploadSrc = meta.realSrcStream 203 } 204 _, err = blobClient.UploadStream(context.Background(), uploadSrc, &azblob.UploadStreamOptions{ 205 BlockSize: int64(uploadSrc.Len()), 206 Metadata: azureMeta, 207 }) 208 } else { 209 var f *os.File 210 f, err = os.Open(dataFile) 211 if err != nil { 212 return err 213 } 214 defer f.Close() 215 216 contentType := "application/octet-stream" 217 contentEncoding := "utf-8" 218 blobOptions := &azblob.UploadFileOptions{ 219 HTTPHeaders: &blob.HTTPHeaders{ 220 BlobContentType: &contentType, 221 BlobContentEncoding: &contentEncoding, 222 }, 223 Metadata: azureMeta, 224 Concurrency: uint16(maxConcurrency), 225 } 226 if meta.options.putAzureCallback != nil { 227 blobOptions.Progress = meta.options.putAzureCallback.call 228 } 229 _, err = blobClient.UploadFile(context.Background(), f, blobOptions) 230 } 231 if err != nil { 232 var se *azcore.ResponseError 233 if errors.As(err, &se) { 234 if se.StatusCode == 403 && util.detectAzureTokenExpireError(se.RawResponse) { 235 meta.resStatus = renewToken 236 } else { 237 meta.resStatus = needRetry 238 meta.lastError = err 239 } 240 return err 241 } 242 meta.resStatus = errStatus 243 return err 244 } 245 246 meta.dstFileSize = meta.uploadSize 247 meta.resStatus = uploaded 248 return nil 249 } 250 251 // cloudUtil implementation 252 func (util *snowflakeAzureClient) nativeDownloadFile( 253 meta *fileMetadata, 254 fullDstFileName string, 255 maxConcurrency int64) error { 256 azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location) 257 if err != nil { 258 return err 259 } 260 path := azureLoc.path + strings.TrimLeft(meta.srcFileName, "/") 261 client, ok := meta.client.(*azblob.Client) 262 if !ok { 263 return &SnowflakeError{ 264 Message: "failed to cast to azure client", 265 } 266 } 267 containerClient, err := container.NewClientWithNoCredential(client.URL(), &container.ClientOptions{}) 268 if err != nil { 269 return &SnowflakeError{ 270 Message: "failed to create container client", 271 } 272 } 273 var blobClient azureAPI 274 blobClient = containerClient.NewBlockBlobClient(path) 275 // for testing only 276 if meta.mockAzureClient != nil { 277 blobClient = meta.mockAzureClient 278 } 279 f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode) 280 if err != nil { 281 return err 282 } 283 defer f.Close() 284 _, err = blobClient.DownloadFile( 285 context.Background(), f, &azblob.DownloadFileOptions{ 286 Concurrency: uint16(maxConcurrency)}) 287 if err != nil { 288 return err 289 } 290 meta.resStatus = downloaded 291 return nil 292 } 293 294 func (util *snowflakeAzureClient) extractContainerNameAndPath(location string) (*azureLocation, error) { 295 stageLocation, err := expandUser(location) 296 if err != nil { 297 return nil, err 298 } 299 containerName := stageLocation 300 path := "" 301 302 if strings.Contains(stageLocation, "/") { 303 containerName = stageLocation[:strings.Index(stageLocation, "/")] 304 path = stageLocation[strings.Index(stageLocation, "/")+1:] 305 if path != "" && !strings.HasSuffix(path, "/") { 306 path += "/" 307 } 308 } 309 return &azureLocation{containerName, path}, nil 310 } 311 312 func (util *snowflakeAzureClient) detectAzureTokenExpireError(resp *http.Response) bool { 313 if resp.StatusCode != 403 { 314 return false 315 } 316 azureErr, err := io.ReadAll(resp.Body) 317 if err != nil { 318 return false 319 } 320 errStr := string(azureErr) 321 return strings.Contains(errStr, "Signature not valid in the specified time frame") || 322 strings.Contains(errStr, "Server failed to authenticate the request") 323 }