github.com/snowflakedb/gosnowflake@v1.9.0/azure_storage_client.go (about)

     1  // Copyright (c) 2021-2023 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/url"
    13  	"os"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
    18  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
    19  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
    20  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
    21  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
    22  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
    23  )
    24  
    25  type snowflakeAzureClient struct {
    26  }
    27  
    28  type azureLocation struct {
    29  	containerName string
    30  	path          string
    31  }
    32  
    33  type azureAPI interface {
    34  	UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
    35  	UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
    36  	DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
    37  	GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
    38  }
    39  
    40  func (util *snowflakeAzureClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) {
    41  	sasToken := info.Creds.AzureSasToken
    42  	u, err := url.Parse(fmt.Sprintf("https://%s.%s/%s%s", info.StorageAccount, info.EndPoint, info.Path, sasToken))
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  	client, err := azblob.NewClientWithNoCredential(u.String(), &azblob.ClientOptions{
    47  		ClientOptions: azcore.ClientOptions{
    48  			Retry: policy.RetryOptions{
    49  				MaxRetries: 60,
    50  				RetryDelay: 2 * time.Second,
    51  			},
    52  			Transport: &http.Client{
    53  				Transport: SnowflakeTransport,
    54  			},
    55  		},
    56  	})
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	return client, nil
    61  }
    62  
    63  // cloudUtil implementation
    64  func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename string) (*fileHeader, error) {
    65  	client, ok := meta.client.(*azblob.Client)
    66  	if !ok {
    67  		return nil, fmt.Errorf("failed to parse client to azblob.Client")
    68  	}
    69  
    70  	azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  	path := azureLoc.path + strings.TrimLeft(filename, "/")
    75  	containerClient, err := container.NewClientWithNoCredential(client.URL(), &container.ClientOptions{})
    76  	if err != nil {
    77  		return nil, &SnowflakeError{
    78  			Message: "failed to create container client",
    79  		}
    80  	}
    81  	var blobClient azureAPI
    82  	blobClient = containerClient.NewBlockBlobClient(path)
    83  	// for testing only
    84  	if meta.mockAzureClient != nil {
    85  		blobClient = meta.mockAzureClient
    86  	}
    87  	resp, err := blobClient.GetProperties(context.Background(), &blob.GetPropertiesOptions{
    88  		AccessConditions: &blob.AccessConditions{},
    89  		CPKInfo:          &blob.CPKInfo{},
    90  	})
    91  	if err != nil {
    92  		var se *azcore.ResponseError
    93  		if errors.As(err, &se) {
    94  			if se.ErrorCode == string(bloberror.BlobNotFound) {
    95  				meta.resStatus = notFoundFile
    96  				return nil, fmt.Errorf("could not find file")
    97  			} else if se.StatusCode == 403 {
    98  				meta.resStatus = renewToken
    99  				return nil, fmt.Errorf("received 403, attempting to renew")
   100  			}
   101  		}
   102  		meta.resStatus = errStatus
   103  		return nil, err
   104  	}
   105  
   106  	meta.resStatus = uploaded
   107  	metadata := resp.Metadata
   108  	var encData encryptionData
   109  
   110  	_, ok = metadata["Encryptiondata"]
   111  	if ok {
   112  		if err = json.Unmarshal([]byte(*metadata["Encryptiondata"]), &encData); err != nil {
   113  			return nil, err
   114  		}
   115  	}
   116  
   117  	matdesc, ok := metadata["Matdesc"]
   118  	if !ok {
   119  		// matdesc is not in response, use empty string
   120  		matdesc = new(string)
   121  	}
   122  	encryptionMetadata := encryptMetadata{
   123  		encData.WrappedContentKey.EncryptionKey,
   124  		encData.ContentEncryptionIV,
   125  		*matdesc,
   126  	}
   127  
   128  	digest, ok := metadata["Sfcdigest"]
   129  	if !ok {
   130  		// sfcdigest is not in response, use empty string
   131  		digest = new(string)
   132  	}
   133  	return &fileHeader{
   134  		*digest,
   135  		int64(len(metadata)),
   136  		&encryptionMetadata,
   137  	}, nil
   138  }
   139  
   140  // cloudUtil implementation
   141  func (util *snowflakeAzureClient) uploadFile(
   142  	dataFile string,
   143  	meta *fileMetadata,
   144  	encryptMeta *encryptMetadata,
   145  	maxConcurrency int,
   146  	multiPartThreshold int64) error {
   147  	azureMeta := map[string]*string{
   148  		"sfcdigest": &meta.sha256Digest,
   149  	}
   150  	if encryptMeta != nil {
   151  		ed := &encryptionData{
   152  			EncryptionMode: "FullBlob",
   153  			WrappedContentKey: contentKey{
   154  				"symmKey1",
   155  				encryptMeta.key,
   156  				"AES_CBC_256",
   157  			},
   158  			EncryptionAgent: encryptionAgent{
   159  				"1.0",
   160  				"AES_CBC_128",
   161  			},
   162  			ContentEncryptionIV: encryptMeta.iv,
   163  			KeyWrappingMetadata: keyMetadata{
   164  				"Java 5.3.0",
   165  			},
   166  		}
   167  		metadata, err := json.Marshal(ed)
   168  		if err != nil {
   169  			return err
   170  		}
   171  		encryptionMetadata := string(metadata)
   172  		azureMeta["encryptiondata"] = &encryptionMetadata
   173  		azureMeta["matdesc"] = &encryptMeta.matdesc
   174  	}
   175  
   176  	azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location)
   177  	if err != nil {
   178  		return err
   179  	}
   180  	path := azureLoc.path + strings.TrimLeft(meta.dstFileName, "/")
   181  	client, ok := meta.client.(*azblob.Client)
   182  	if !ok {
   183  		return &SnowflakeError{
   184  			Message: "failed to cast to azure client",
   185  		}
   186  	}
   187  	containerClient, err := container.NewClientWithNoCredential(client.URL(), &container.ClientOptions{})
   188  	if err != nil {
   189  		return &SnowflakeError{
   190  			Message: "failed to create container client",
   191  		}
   192  	}
   193  	var blobClient azureAPI
   194  	blobClient = containerClient.NewBlockBlobClient(path)
   195  	// for testing only
   196  	if meta.mockAzureClient != nil {
   197  		blobClient = meta.mockAzureClient
   198  	}
   199  	if meta.srcStream != nil {
   200  		uploadSrc := meta.srcStream
   201  		if meta.realSrcStream != nil {
   202  			uploadSrc = meta.realSrcStream
   203  		}
   204  		_, err = blobClient.UploadStream(context.Background(), uploadSrc, &azblob.UploadStreamOptions{
   205  			BlockSize: int64(uploadSrc.Len()),
   206  			Metadata:  azureMeta,
   207  		})
   208  	} else {
   209  		var f *os.File
   210  		f, err = os.Open(dataFile)
   211  		if err != nil {
   212  			return err
   213  		}
   214  		defer f.Close()
   215  
   216  		contentType := "application/octet-stream"
   217  		contentEncoding := "utf-8"
   218  		blobOptions := &azblob.UploadFileOptions{
   219  			HTTPHeaders: &blob.HTTPHeaders{
   220  				BlobContentType:     &contentType,
   221  				BlobContentEncoding: &contentEncoding,
   222  			},
   223  			Metadata:    azureMeta,
   224  			Concurrency: uint16(maxConcurrency),
   225  		}
   226  		if meta.options.putAzureCallback != nil {
   227  			blobOptions.Progress = meta.options.putAzureCallback.call
   228  		}
   229  		_, err = blobClient.UploadFile(context.Background(), f, blobOptions)
   230  	}
   231  	if err != nil {
   232  		var se *azcore.ResponseError
   233  		if errors.As(err, &se) {
   234  			if se.StatusCode == 403 && util.detectAzureTokenExpireError(se.RawResponse) {
   235  				meta.resStatus = renewToken
   236  			} else {
   237  				meta.resStatus = needRetry
   238  				meta.lastError = err
   239  			}
   240  			return err
   241  		}
   242  		meta.resStatus = errStatus
   243  		return err
   244  	}
   245  
   246  	meta.dstFileSize = meta.uploadSize
   247  	meta.resStatus = uploaded
   248  	return nil
   249  }
   250  
   251  // cloudUtil implementation
   252  func (util *snowflakeAzureClient) nativeDownloadFile(
   253  	meta *fileMetadata,
   254  	fullDstFileName string,
   255  	maxConcurrency int64) error {
   256  	azureLoc, err := util.extractContainerNameAndPath(meta.stageInfo.Location)
   257  	if err != nil {
   258  		return err
   259  	}
   260  	path := azureLoc.path + strings.TrimLeft(meta.srcFileName, "/")
   261  	client, ok := meta.client.(*azblob.Client)
   262  	if !ok {
   263  		return &SnowflakeError{
   264  			Message: "failed to cast to azure client",
   265  		}
   266  	}
   267  	containerClient, err := container.NewClientWithNoCredential(client.URL(), &container.ClientOptions{})
   268  	if err != nil {
   269  		return &SnowflakeError{
   270  			Message: "failed to create container client",
   271  		}
   272  	}
   273  	var blobClient azureAPI
   274  	blobClient = containerClient.NewBlockBlobClient(path)
   275  	// for testing only
   276  	if meta.mockAzureClient != nil {
   277  		blobClient = meta.mockAzureClient
   278  	}
   279  	f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
   280  	if err != nil {
   281  		return err
   282  	}
   283  	defer f.Close()
   284  	_, err = blobClient.DownloadFile(
   285  		context.Background(), f, &azblob.DownloadFileOptions{
   286  			Concurrency: uint16(maxConcurrency)})
   287  	if err != nil {
   288  		return err
   289  	}
   290  	meta.resStatus = downloaded
   291  	return nil
   292  }
   293  
   294  func (util *snowflakeAzureClient) extractContainerNameAndPath(location string) (*azureLocation, error) {
   295  	stageLocation, err := expandUser(location)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  	containerName := stageLocation
   300  	path := ""
   301  
   302  	if strings.Contains(stageLocation, "/") {
   303  		containerName = stageLocation[:strings.Index(stageLocation, "/")]
   304  		path = stageLocation[strings.Index(stageLocation, "/")+1:]
   305  		if path != "" && !strings.HasSuffix(path, "/") {
   306  			path += "/"
   307  		}
   308  	}
   309  	return &azureLocation{containerName, path}, nil
   310  }
   311  
   312  func (util *snowflakeAzureClient) detectAzureTokenExpireError(resp *http.Response) bool {
   313  	if resp.StatusCode != 403 {
   314  		return false
   315  	}
   316  	azureErr, err := io.ReadAll(resp.Body)
   317  	if err != nil {
   318  		return false
   319  	}
   320  	errStr := string(azureErr)
   321  	return strings.Contains(errStr, "Signature not valid in the specified time frame") ||
   322  		strings.Contains(errStr, "Server failed to authenticate the request")
   323  }