github.com/snowflakedb/gosnowflake@v1.9.0/encrypt_util.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"crypto/aes"
     8  	"crypto/cipher"
     9  	"crypto/rand"
    10  	"encoding/base64"
    11  	"encoding/json"
    12  	"fmt"
    13  	"io"
    14  	"os"
    15  	"strconv"
    16  )
    17  
    18  type snowflakeFileEncryption struct {
    19  	QueryStageMasterKey string `json:"queryStageMasterKey,omitempty"`
    20  	QueryID             string `json:"queryId,omitempty"`
    21  	SMKID               int64  `json:"smkId,omitempty"`
    22  }
    23  
    24  // PUT requests return a single encryptionMaterial object whereas GET requests
    25  // return a slice (array) of encryptionMaterial objects, both under the field
    26  // 'encryptionMaterial'
    27  type encryptionWrapper struct {
    28  	snowflakeFileEncryption
    29  	EncryptionMaterials []snowflakeFileEncryption
    30  }
    31  
    32  // override default behavior for wrapper
    33  func (ew *encryptionWrapper) UnmarshalJSON(data []byte) error {
    34  	// if GET, unmarshal slice of encryptionMaterial
    35  	if err := json.Unmarshal(data, &ew.EncryptionMaterials); err == nil {
    36  		return err
    37  	}
    38  	// else (if PUT), unmarshal the encryptionMaterial itself
    39  	return json.Unmarshal(data, &ew.snowflakeFileEncryption)
    40  }
    41  
// encryptMetadata is produced by encryptStream and consumed by decryptFile.
// key is the StdEncoding base64 of the file key after padding and ECB
// encryption with the stage master key; iv is the StdEncoding base64 of the
// CBC initialization vector; matdesc is the JSON text returned by
// matdescToUnicode for the file's materialDescriptor.
type encryptMetadata struct {
	key     string
	iv      string
	matdesc string
}
    47  
    48  // encryptStream encrypts a stream buffer using AES128 block cipher in CBC mode
    49  // with PKCS5 padding
    50  func encryptStream(
    51  	sfe *snowflakeFileEncryption,
    52  	src io.Reader,
    53  	out io.Writer,
    54  	chunkSize int) (*encryptMetadata, error) {
    55  	if chunkSize == 0 {
    56  		chunkSize = aes.BlockSize * 4 * 1024
    57  	}
    58  	decodedKey, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	keySize := len(decodedKey)
    63  
    64  	fileKey := getSecureRandom(keySize)
    65  	block, err := aes.NewCipher(fileKey)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	ivData := getSecureRandom(block.BlockSize())
    70  
    71  	mode := cipher.NewCBCEncrypter(block, ivData)
    72  	cipherText := make([]byte, chunkSize)
    73  	chunk := make([]byte, chunkSize)
    74  
    75  	// encrypt file with CBC
    76  	padded := false
    77  	for {
    78  		// read the stream buffer up to len(chunk) bytes into chunk
    79  		// note that all spaces in chunk may be used even if Read() returns n < len(chunk)
    80  		n, err := src.Read(chunk)
    81  		if n == 0 || err != nil {
    82  			break
    83  		} else if n%aes.BlockSize != 0 {
    84  			// add padding to the end of the chunk and update the length n
    85  			chunk = padBytesLength(chunk[:n], aes.BlockSize)
    86  			n = len(chunk)
    87  			padded = true
    88  		}
    89  		// make sure only n bytes of chunk is used
    90  		mode.CryptBlocks(cipherText, chunk[:n])
    91  		out.Write(cipherText[:n])
    92  	}
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	// add padding if not yet added
    98  	if !padded {
    99  		padding := bytes.Repeat([]byte(string(rune(aes.BlockSize))), aes.BlockSize)
   100  		mode.CryptBlocks(cipherText, padding)
   101  		out.Write(cipherText[:len(padding)])
   102  	}
   103  
   104  	// encrypt key with ECB
   105  	fileKey = padBytesLength(fileKey, block.BlockSize())
   106  	encryptedFileKey := make([]byte, len(fileKey))
   107  	if err = encryptECB(encryptedFileKey, fileKey, decodedKey); err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	matDesc := materialDescriptor{
   112  		strconv.Itoa(int(sfe.SMKID)),
   113  		sfe.QueryID,
   114  		strconv.Itoa(keySize * 8),
   115  	}
   116  
   117  	matDescUnicode, err := matdescToUnicode(matDesc)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	return &encryptMetadata{
   122  		base64.StdEncoding.EncodeToString(encryptedFileKey),
   123  		base64.StdEncoding.EncodeToString(ivData),
   124  		matDescUnicode,
   125  	}, nil
   126  }
   127  
   128  func encryptECB(encrypted []byte, fileKey []byte, decodedKey []byte) error {
   129  	block, err := aes.NewCipher(decodedKey)
   130  	if err != nil {
   131  		return err
   132  	}
   133  	if len(fileKey)%block.BlockSize() != 0 {
   134  		return fmt.Errorf("input not full of blocks")
   135  	}
   136  	if len(encrypted) < len(fileKey) {
   137  		return fmt.Errorf("output length is smaller than input length")
   138  	}
   139  	for len(fileKey) > 0 {
   140  		block.Encrypt(encrypted, fileKey[:block.BlockSize()])
   141  		encrypted = encrypted[block.BlockSize():]
   142  		fileKey = fileKey[block.BlockSize():]
   143  	}
   144  	return nil
   145  }
   146  
   147  func decryptECB(decrypted []byte, keyBytes []byte, decodedKey []byte) error {
   148  	block, err := aes.NewCipher(decodedKey)
   149  	if err != nil {
   150  		return err
   151  	}
   152  	if len(keyBytes)%block.BlockSize() != 0 {
   153  		return fmt.Errorf("input not full of blocks")
   154  	}
   155  	if len(decrypted) < len(keyBytes) {
   156  		return fmt.Errorf("output length is smaller than input length")
   157  	}
   158  	for len(keyBytes) > 0 {
   159  		block.Decrypt(decrypted, keyBytes[:block.BlockSize()])
   160  		keyBytes = keyBytes[block.BlockSize():]
   161  		decrypted = decrypted[block.BlockSize():]
   162  	}
   163  	return nil
   164  }
   165  
   166  func encryptFile(
   167  	sfe *snowflakeFileEncryption,
   168  	filename string,
   169  	chunkSize int,
   170  	tmpDir string) (
   171  	*encryptMetadata, string, error) {
   172  	if chunkSize == 0 {
   173  		chunkSize = aes.BlockSize * 4 * 1024
   174  	}
   175  	tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#")
   176  	if err != nil {
   177  		return nil, "", err
   178  	}
   179  	defer tmpOutputFile.Close()
   180  	infile, err := os.OpenFile(filename, os.O_CREATE|os.O_RDONLY, readWriteFileMode)
   181  	if err != nil {
   182  		return nil, "", err
   183  	}
   184  	defer infile.Close()
   185  
   186  	meta, err := encryptStream(sfe, infile, tmpOutputFile, chunkSize)
   187  	if err != nil {
   188  		return nil, "", err
   189  	}
   190  	return meta, tmpOutputFile.Name(), nil
   191  }
   192  
   193  func decryptFile(
   194  	metadata *encryptMetadata,
   195  	sfe *snowflakeFileEncryption,
   196  	filename string,
   197  	chunkSize int,
   198  	tmpDir string) (
   199  	string, error) {
   200  	if chunkSize == 0 {
   201  		chunkSize = aes.BlockSize * 4 * 1024
   202  	}
   203  	decodedKey, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey)
   204  	if err != nil {
   205  		return "", err
   206  	}
   207  	keyBytes, err := base64.StdEncoding.DecodeString(metadata.key) // encrypted file key
   208  	if err != nil {
   209  		return "", err
   210  	}
   211  	ivBytes, err := base64.StdEncoding.DecodeString(metadata.iv)
   212  	if err != nil {
   213  		return "", err
   214  	}
   215  
   216  	// decrypt file key
   217  	decryptedKey := make([]byte, len(keyBytes))
   218  	if err = decryptECB(decryptedKey, keyBytes, decodedKey); err != nil {
   219  		return "", err
   220  	}
   221  	decryptedKey, err = paddingTrim(decryptedKey)
   222  	if err != nil {
   223  		return "", err
   224  	}
   225  
   226  	// decrypt file
   227  	block, err := aes.NewCipher(decryptedKey)
   228  	if err != nil {
   229  		return "", err
   230  	}
   231  	mode := cipher.NewCBCDecrypter(block, ivBytes)
   232  
   233  	tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#")
   234  	if err != nil {
   235  		return "", err
   236  	}
   237  	defer tmpOutputFile.Close()
   238  	infile, err := os.Open(filename)
   239  	if err != nil {
   240  		return "", err
   241  	}
   242  	defer infile.Close()
   243  	var totalFileSize int
   244  	var prevChunk []byte
   245  	for {
   246  		chunk := make([]byte, chunkSize)
   247  		n, err := infile.Read(chunk)
   248  		if n == 0 || err != nil {
   249  			break
   250  		} else if n%aes.BlockSize != 0 {
   251  			// add padding to the end of the chunk and update the length n
   252  			chunk = padBytesLength(chunk[:n], aes.BlockSize)
   253  			n = len(chunk)
   254  		}
   255  		totalFileSize += n
   256  		chunk = chunk[:n]
   257  		mode.CryptBlocks(chunk, chunk)
   258  		tmpOutputFile.Write(chunk)
   259  		prevChunk = chunk
   260  	}
   261  	if err != nil {
   262  		return "", err
   263  	}
   264  	if prevChunk != nil {
   265  		totalFileSize -= paddingOffset(prevChunk)
   266  	}
   267  	tmpOutputFile.Truncate(int64(totalFileSize))
   268  	return tmpOutputFile.Name(), nil
   269  }
   270  
   271  type materialDescriptor struct {
   272  	SmkID   string `json:"smkId"`
   273  	QueryID string `json:"queryId"`
   274  	KeySize string `json:"keySize"`
   275  }
   276  
   277  func matdescToUnicode(matdesc materialDescriptor) (string, error) {
   278  	s, err := json.Marshal(&matdesc)
   279  	if err != nil {
   280  		return "", err
   281  	}
   282  	return string(s), nil
   283  }
   284  
   285  func getSecureRandom(byteLength int) []byte {
   286  	token := make([]byte, byteLength)
   287  	rand.Read(token)
   288  	return token
   289  }
   290  
   291  func padBytesLength(src []byte, blockSize int) []byte {
   292  	padLength := blockSize - len(src)%blockSize
   293  	padText := bytes.Repeat([]byte{byte(padLength)}, padLength)
   294  	return append(src, padText...)
   295  }
   296  
   297  func paddingTrim(src []byte) ([]byte, error) {
   298  	unpadding := src[len(src)-1]
   299  	n := int(unpadding)
   300  	if n == 0 || n > len(src) {
   301  		return nil, &SnowflakeError{
   302  			Number:  ErrInvalidPadding,
   303  			Message: errMsgInvalidPadding,
   304  		}
   305  	}
   306  	return src[:len(src)-n], nil
   307  }
   308  
   309  func paddingOffset(src []byte) int {
   310  	length := len(src)
   311  	return int(src[length-1])
   312  }
   313  
// contentKey is the wrapped-content-key entry of encryptionData. It is
// encoded with the JSON names "KeyId", "EncryptedKey" and "Algorithm"; note
// that the EncryptionKey field is serialized as "EncryptedKey". Empty
// strings are omitted.
// NOTE(review): the field names appear to follow an external storage
// provider's encryption metadata schema; confirm against the caller before
// renaming anything.
type contentKey struct {
	KeyID         string `json:"KeyId,omitempty"`
	EncryptionKey string `json:"EncryptedKey,omitempty"`
	Algorithm     string `json:"Algorithm,omitempty"`
}
   319  
// encryptionAgent is the encryption-agent entry of encryptionData, encoded
// as {"Protocol": ..., "EncryptionAlgorithm": ...}. Empty strings are
// omitted.
type encryptionAgent struct {
	Protocol            string `json:"Protocol,omitempty"`
	EncryptionAlgorithm string `json:"EncryptionAlgorithm,omitempty"`
}
   324  
// keyMetadata is the key-wrapping-metadata entry of encryptionData, encoded
// as {"EncryptionLibrary": ...}. An empty EncryptionLibrary is omitted.
type keyMetadata struct {
	EncryptionLibrary string `json:"EncryptionLibrary,omitempty"`
}
   328  
// encryptionData is the top-level encryption metadata document, encoded with
// the JSON names given in the tags.
//
// NOTE(review): `omitempty` has no effect on the struct-valued fields
// (WrappedContentKey, EncryptionAgent, KeyWrappingMetadata). encoding/json
// never treats a struct value as empty, so these keys are always emitted
// (as {} when zero). Only EncryptionMode and ContentEncryptionIV are
// actually omitted when empty. Use pointer fields (or `omitzero` on Go
// 1.24+) if omission was intended.
type encryptionData struct {
	EncryptionMode      string          `json:"EncryptionMode,omitempty"`
	WrappedContentKey   contentKey      `json:"WrappedContentKey,omitempty"`
	EncryptionAgent     encryptionAgent `json:"EncryptionAgent,omitempty"`
	ContentEncryptionIV string          `json:"ContentEncryptionIV,omitempty"`
	KeyWrappingMetadata keyMetadata     `json:"KeyWrappingMetadata,omitempty"`
}