github.com/snowflakedb/gosnowflake@v1.9.0/gcs_storage_client.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"os"
    12  	"strconv"
    13  	"strings"
    14  )
    15  
// Header names and keys used when exchanging object metadata with GCS.
const (
	// gcsMetadataPrefix is the common prefix of the custom metadata header
	// names declared below.
	gcsMetadataPrefix             = "x-goog-meta-"
	// gcsMetadataSfcDigest is the header that carries the file digest: it is
	// sent with meta.sha256Digest on upload and read back from HEAD/GET
	// responses.
	gcsMetadataSfcDigest          = gcsMetadataPrefix + sfcDigest
	// gcsMetadataMatdescKey is the header that carries the encryption
	// material descriptor (encryptMetadata.matdesc).
	gcsMetadataMatdescKey         = gcsMetadataPrefix + "matdesc"
	// gcsMetadataEncryptionDataProp is the header that carries the
	// JSON-encoded encryptionData of an encrypted file.
	gcsMetadataEncryptionDataProp = gcsMetadataPrefix + "encryptiondata"
	// gcsFileHeaderDigest is a plain key without the metadata prefix; the code
	// in this file never sends it as a request header.
	gcsFileHeaderDigest           = "gcs-file-header-digest"
)
    23  
    24  type snowflakeGcsClient struct {
    25  }
    26  
    27  type gcsLocation struct {
    28  	bucketName string
    29  	path       string
    30  }
    31  
    32  func (util *snowflakeGcsClient) createClient(info *execResponseStageInfo, _ bool) (cloudClient, error) {
    33  	if info.Creds.GcsAccessToken != "" {
    34  		logger.Debug("Using GCS downscoped token")
    35  		return info.Creds.GcsAccessToken, nil
    36  	}
    37  	logger.Debugf("No access token received from GS, using presigned url: %s", info.PresignedURL)
    38  	return "", nil
    39  }
    40  
    41  // cloudUtil implementation
    42  func (util *snowflakeGcsClient) getFileHeader(meta *fileMetadata, filename string) (*fileHeader, error) {
    43  	if meta.resStatus == uploaded || meta.resStatus == downloaded {
    44  		return &fileHeader{
    45  			digest:             meta.gcsFileHeaderDigest,
    46  			contentLength:      meta.gcsFileHeaderContentLength,
    47  			encryptionMetadata: meta.gcsFileHeaderEncryptionMeta,
    48  		}, nil
    49  	}
    50  	if meta.presignedURL != nil {
    51  		meta.resStatus = notFoundFile
    52  	} else {
    53  		URL, err := util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(filename, "/"))
    54  		if err != nil {
    55  			return nil, err
    56  		}
    57  		accessToken, ok := meta.client.(string)
    58  		if !ok {
    59  			return nil, fmt.Errorf("interface convertion. expected type string but got %T", meta.client)
    60  		}
    61  		gcsHeaders := map[string]string{
    62  			"Authorization": "Bearer " + accessToken,
    63  		}
    64  
    65  		req, err := http.NewRequest("HEAD", URL.String(), nil)
    66  		if err != nil {
    67  			return nil, err
    68  		}
    69  		for k, v := range gcsHeaders {
    70  			req.Header.Add(k, v)
    71  		}
    72  		client := newGcsClient()
    73  		// for testing only
    74  		if meta.mockGcsClient != nil {
    75  			client = meta.mockGcsClient
    76  		}
    77  		resp, err := client.Do(req)
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  		if resp.StatusCode != http.StatusOK {
    82  			meta.lastError = fmt.Errorf(resp.Status)
    83  			meta.resStatus = errStatus
    84  			if resp.StatusCode == 403 || resp.StatusCode == 408 || resp.StatusCode == 429 || resp.StatusCode == 500 || resp.StatusCode == 503 {
    85  				meta.lastError = fmt.Errorf(resp.Status)
    86  				meta.resStatus = needRetry
    87  				return nil, meta.lastError
    88  			}
    89  			if resp.StatusCode == 404 {
    90  				meta.resStatus = notFoundFile
    91  			} else if util.isTokenExpired(resp) {
    92  				meta.lastError = fmt.Errorf(resp.Status)
    93  				meta.resStatus = renewToken
    94  			}
    95  			return nil, meta.lastError
    96  		}
    97  
    98  		digest := resp.Header.Get(gcsMetadataSfcDigest)
    99  		contentLength, err := strconv.Atoi(resp.Header.Get("content-length"))
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  		var encryptionMeta *encryptMetadata
   104  		if resp.Header.Get(gcsMetadataEncryptionDataProp) != "" {
   105  			var encryptData *encryptionData
   106  			err := json.Unmarshal([]byte(resp.Header.Get(gcsMetadataEncryptionDataProp)), &encryptData)
   107  			if err != nil {
   108  				logger.Error(err)
   109  			}
   110  			if encryptData != nil {
   111  				encryptionMeta = &encryptMetadata{
   112  					key: encryptData.WrappedContentKey.EncryptionKey,
   113  					iv:  encryptData.ContentEncryptionIV,
   114  				}
   115  				if resp.Header.Get(gcsMetadataMatdescKey) != "" {
   116  					encryptionMeta.matdesc = resp.Header.Get(gcsMetadataMatdescKey)
   117  				}
   118  			}
   119  		}
   120  		meta.resStatus = uploaded
   121  		return &fileHeader{
   122  			digest:             digest,
   123  			contentLength:      int64(contentLength),
   124  			encryptionMetadata: encryptionMeta,
   125  		}, nil
   126  	}
   127  	return nil, nil
   128  }
   129  
   130  type gcsAPI interface {
   131  	Do(req *http.Request) (*http.Response, error)
   132  }
   133  
   134  // cloudUtil implementation
   135  func (util *snowflakeGcsClient) uploadFile(
   136  	dataFile string,
   137  	meta *fileMetadata,
   138  	encryptMeta *encryptMetadata,
   139  	maxConcurrency int,
   140  	multiPartThreshold int64) error {
   141  	uploadURL := meta.presignedURL
   142  	var accessToken string
   143  	var err error
   144  
   145  	if uploadURL == nil {
   146  		uploadURL, err = util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(meta.dstFileName, "/"))
   147  		if err != nil {
   148  			return err
   149  		}
   150  		var ok bool
   151  		accessToken, ok = meta.client.(string)
   152  		if !ok {
   153  			return fmt.Errorf("interface convertion. expected type string but got %T", meta.client)
   154  		}
   155  	}
   156  
   157  	var contentEncoding string
   158  	if meta.dstCompressionType != nil {
   159  		contentEncoding = strings.ToLower(meta.dstCompressionType.name)
   160  	}
   161  
   162  	if contentEncoding == "gzip" {
   163  		contentEncoding = ""
   164  	}
   165  
   166  	gcsHeaders := make(map[string]string)
   167  	gcsHeaders[httpHeaderContentEncoding] = contentEncoding
   168  	gcsHeaders[gcsMetadataSfcDigest] = meta.sha256Digest
   169  	if accessToken != "" {
   170  		gcsHeaders["Authorization"] = "Bearer " + accessToken
   171  	}
   172  
   173  	if encryptMeta != nil {
   174  		encryptData := encryptionData{
   175  			"FullBlob",
   176  			contentKey{
   177  				"symmKey1",
   178  				encryptMeta.key,
   179  				"AES_CBC_256",
   180  			},
   181  			encryptionAgent{
   182  				"1.0",
   183  				"AES_CBC_256",
   184  			},
   185  			encryptMeta.iv,
   186  			keyMetadata{
   187  				"Java 5.3.0",
   188  			},
   189  		}
   190  		b, err := json.Marshal(&encryptData)
   191  		if err != nil {
   192  			return err
   193  		}
   194  		gcsHeaders[gcsMetadataEncryptionDataProp] = string(b)
   195  		gcsHeaders[gcsMetadataMatdescKey] = encryptMeta.matdesc
   196  	}
   197  
   198  	var uploadSrc io.Reader
   199  	if meta.srcStream != nil {
   200  		uploadSrc = meta.srcStream
   201  		if meta.realSrcStream != nil {
   202  			uploadSrc = meta.realSrcStream
   203  		}
   204  	} else {
   205  		uploadSrc, err = os.Open(dataFile)
   206  		if err != nil {
   207  			return err
   208  		}
   209  	}
   210  
   211  	req, err := http.NewRequest("PUT", uploadURL.String(), uploadSrc)
   212  	if err != nil {
   213  		return err
   214  	}
   215  	for k, v := range gcsHeaders {
   216  		req.Header.Add(k, v)
   217  	}
   218  	client := newGcsClient()
   219  	// for testing only
   220  	if meta.mockGcsClient != nil {
   221  		client = meta.mockGcsClient
   222  	}
   223  	resp, err := client.Do(req)
   224  	if err != nil {
   225  		return err
   226  	}
   227  	if resp.StatusCode != http.StatusOK {
   228  		if resp.StatusCode == 403 || resp.StatusCode == 408 || resp.StatusCode == 429 || resp.StatusCode == 500 || resp.StatusCode == 503 {
   229  			meta.lastError = fmt.Errorf(resp.Status)
   230  			meta.resStatus = needRetry
   231  		} else if accessToken == "" && resp.StatusCode == 400 && meta.lastError == nil {
   232  			meta.lastError = fmt.Errorf(resp.Status)
   233  			meta.resStatus = renewPresignedURL
   234  		} else if accessToken != "" && util.isTokenExpired(resp) {
   235  			meta.lastError = fmt.Errorf(resp.Status)
   236  			meta.resStatus = renewToken
   237  		} else {
   238  			meta.lastError = fmt.Errorf(resp.Status)
   239  		}
   240  		return meta.lastError
   241  	}
   242  
   243  	if meta.options.putCallback != nil {
   244  		meta.options.putCallback = &snowflakeProgressPercentage{
   245  			filename:        dataFile,
   246  			fileSize:        float64(meta.srcFileSize),
   247  			outputStream:    meta.options.putCallbackOutputStream,
   248  			showProgressBar: meta.options.showProgressBar,
   249  		}
   250  	}
   251  
   252  	meta.dstFileSize = meta.uploadSize
   253  	meta.resStatus = uploaded
   254  
   255  	meta.gcsFileHeaderDigest = gcsHeaders[gcsFileHeaderDigest]
   256  	meta.gcsFileHeaderContentLength = meta.uploadSize
   257  	if err = json.Unmarshal([]byte(gcsHeaders[gcsMetadataEncryptionDataProp]), &encryptMeta); err != nil {
   258  		return err
   259  	}
   260  	meta.gcsFileHeaderEncryptionMeta = encryptMeta
   261  	return nil
   262  }
   263  
   264  // cloudUtil implementation
   265  func (util *snowflakeGcsClient) nativeDownloadFile(
   266  	meta *fileMetadata,
   267  	fullDstFileName string,
   268  	maxConcurrency int64) error {
   269  	downloadURL := meta.presignedURL
   270  	var accessToken string
   271  	var err error
   272  	gcsHeaders := make(map[string]string)
   273  
   274  	if downloadURL == nil || downloadURL.String() == "" {
   275  		downloadURL, err = util.generateFileURL(meta.stageInfo.Location, strings.TrimLeft(meta.srcFileName, "/"))
   276  		if err != nil {
   277  			return err
   278  		}
   279  		var ok bool
   280  		accessToken, ok = meta.client.(string)
   281  		if !ok {
   282  			return fmt.Errorf("interface convertion. expected type string but got %T", meta.client)
   283  		}
   284  		if accessToken != "" {
   285  			gcsHeaders["Authorization"] = "Bearer " + accessToken
   286  		}
   287  	}
   288  
   289  	req, err := http.NewRequest("GET", downloadURL.String(), nil)
   290  	if err != nil {
   291  		return err
   292  	}
   293  	for k, v := range gcsHeaders {
   294  		req.Header.Add(k, v)
   295  	}
   296  	client := newGcsClient()
   297  	// for testing only
   298  	if meta.mockGcsClient != nil {
   299  		client = meta.mockGcsClient
   300  	}
   301  	resp, err := client.Do(req)
   302  	if err != nil {
   303  		return err
   304  	}
   305  	if resp.StatusCode != http.StatusOK {
   306  		if resp.StatusCode == 403 || resp.StatusCode == 408 || resp.StatusCode == 429 || resp.StatusCode == 500 || resp.StatusCode == 503 {
   307  			meta.lastError = fmt.Errorf(resp.Status)
   308  			meta.resStatus = needRetry
   309  		} else if resp.StatusCode == 404 {
   310  			meta.lastError = fmt.Errorf(resp.Status)
   311  			meta.resStatus = notFoundFile
   312  		} else if accessToken == "" && resp.StatusCode == 400 && meta.lastError == nil {
   313  			meta.lastError = fmt.Errorf(resp.Status)
   314  			meta.resStatus = renewPresignedURL
   315  		} else if accessToken != "" && util.isTokenExpired(resp) {
   316  			meta.lastError = fmt.Errorf(resp.Status)
   317  			meta.resStatus = renewToken
   318  		} else {
   319  			meta.lastError = fmt.Errorf(resp.Status)
   320  
   321  		}
   322  		return meta.lastError
   323  	}
   324  
   325  	f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
   326  	if err != nil {
   327  		return err
   328  	}
   329  	defer f.Close()
   330  	if _, err = io.Copy(f, resp.Body); err != nil {
   331  		return err
   332  	}
   333  
   334  	var encryptMeta encryptMetadata
   335  	if resp.Header.Get(gcsMetadataEncryptionDataProp) != "" {
   336  		var encryptData *encryptionData
   337  		if err = json.Unmarshal([]byte(resp.Header.Get(gcsMetadataEncryptionDataProp)), &encryptData); err != nil {
   338  			return err
   339  		}
   340  		if encryptData != nil {
   341  			encryptMeta = encryptMetadata{
   342  				encryptData.WrappedContentKey.EncryptionKey,
   343  				encryptData.ContentEncryptionIV,
   344  				"",
   345  			}
   346  			if key := resp.Header.Get(gcsMetadataMatdescKey); key != "" {
   347  				encryptMeta.matdesc = key
   348  			}
   349  		}
   350  	}
   351  
   352  	fi, err := os.Stat(fullDstFileName)
   353  	if err != nil {
   354  		return err
   355  	}
   356  	meta.srcFileSize = fi.Size()
   357  	meta.resStatus = downloaded
   358  	meta.gcsFileHeaderDigest = resp.Header.Get(gcsMetadataSfcDigest)
   359  	meta.gcsFileHeaderContentLength = resp.ContentLength
   360  	meta.gcsFileHeaderEncryptionMeta = &encryptMeta
   361  	return nil
   362  }
   363  
   364  func (util *snowflakeGcsClient) extractBucketNameAndPath(location string) *gcsLocation {
   365  	containerName := location
   366  	var path string
   367  	if strings.Contains(location, "/") {
   368  		containerName = location[:strings.Index(location, "/")]
   369  		path = location[strings.Index(location, "/")+1:]
   370  		if path != "" && !strings.HasSuffix(path, "/") {
   371  			path += "/"
   372  		}
   373  	}
   374  	return &gcsLocation{containerName, path}
   375  }
   376  
   377  func (util *snowflakeGcsClient) generateFileURL(stageLocation string, filename string) (*url.URL, error) {
   378  	gcsLoc := util.extractBucketNameAndPath(stageLocation)
   379  	fullFilePath := gcsLoc.path + filename
   380  	URL, err := url.Parse("https://storage.googleapis.com/" + gcsLoc.bucketName + "/" + url.QueryEscape(fullFilePath))
   381  	if err != nil {
   382  		return nil, err
   383  	}
   384  	return URL, nil
   385  }
   386  
   387  func (util *snowflakeGcsClient) isTokenExpired(resp *http.Response) bool {
   388  	return resp.StatusCode == 401
   389  }
   390  
   391  func newGcsClient() gcsAPI {
   392  	return &http.Client{
   393  		Transport: SnowflakeTransport,
   394  	}
   395  }