github.com/snowflakedb/gosnowflake@v1.9.0/encrypt_util_test.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bufio"
     7  	"compress/gzip"
     8  	"fmt"
     9  	"io"
    10  	"math/rand"
    11  	"os"
    12  	"os/exec"
    13  	"path"
    14  	"path/filepath"
    15  	"strconv"
    16  	"testing"
    17  	"time"
    18  )
    19  
// timeFormat is the layout (expressed with Go's reference time) used to render
// the random timestamp columns written by generateKLinesOfNFiles. It carries no
// time zone or fractional seconds.
const timeFormat = "2006-01-02T15:04:05"
    21  
// encryptDecryptTestFile describes one generated fixture for
// TestEncryptDecryptFilePadding: a file made of numberOfLines rows, each
// numberOfBytesInEachRow bytes long.
type encryptDecryptTestFile struct {
	// numberOfBytesInEachRow is the size of every row in bytes, counting the
	// trailing newline.
	numberOfBytesInEachRow int
	// numberOfLines is the number of rows in the file; it is also the line
	// count expected after the encrypt/decrypt round trip.
	numberOfLines          int
}
    26  
    27  func TestEncryptDecryptFile(t *testing.T) {
    28  	encMat := snowflakeFileEncryption{
    29  		"ztke8tIdVt1zmlQIZm0BMA==",
    30  		"123873c7-3a66-40c4-ab89-e3722fbccce1",
    31  		3112,
    32  	}
    33  	data := "test data"
    34  	inputFile := "test_encrypt_decrypt_file"
    35  
    36  	fd, err := os.Create(inputFile)
    37  	if err != nil {
    38  		t.Error(err)
    39  	}
    40  	defer fd.Close()
    41  	defer os.Remove(inputFile)
    42  	if _, err = fd.Write([]byte(data)); err != nil {
    43  		t.Error(err)
    44  	}
    45  
    46  	metadata, encryptedFile, err := encryptFile(&encMat, inputFile, 0, "")
    47  	if err != nil {
    48  		t.Error(err)
    49  	}
    50  	defer os.Remove(encryptedFile)
    51  	decryptedFile, err := decryptFile(metadata, &encMat, encryptedFile, 0, "")
    52  	if err != nil {
    53  		t.Error(err)
    54  	}
    55  	defer os.Remove(decryptedFile)
    56  
    57  	fd, err = os.Open(decryptedFile)
    58  	if err != nil {
    59  		t.Error(err)
    60  	}
    61  	defer fd.Close()
    62  	content, err := io.ReadAll(fd)
    63  	if err != nil {
    64  		t.Error(err)
    65  	}
    66  	if string(content) != data {
    67  		t.Fatalf("data did not match content. expected: %v, got: %v", data, string(content))
    68  	}
    69  }
    70  
    71  func TestEncryptDecryptFilePadding(t *testing.T) {
    72  	encMat := snowflakeFileEncryption{
    73  		"ztke8tIdVt1zmlQIZm0BMA==",
    74  		"123873c7-3a66-40c4-ab89-e3722fbccce1",
    75  		3112,
    76  	}
    77  
    78  	testcases := []encryptDecryptTestFile{
    79  		// File size is a multiple of 65536 bytes (chunkSize)
    80  		{numberOfBytesInEachRow: 8, numberOfLines: 16384},
    81  		{numberOfBytesInEachRow: 16, numberOfLines: 4096},
    82  		// File size is not a multiple of 65536 bytes (chunkSize)
    83  		{numberOfBytesInEachRow: 8, numberOfLines: 10240},
    84  		{numberOfBytesInEachRow: 16, numberOfLines: 6144},
    85  		// The second chunk's size is a multiple of 16 bytes (aes.BlockSize)
    86  		{numberOfBytesInEachRow: 16, numberOfLines: 4097},
    87  		// The second chunk's size is not a multiple of 16 bytes (aes.BlockSize)
    88  		{numberOfBytesInEachRow: 12, numberOfLines: 5462},
    89  		{numberOfBytesInEachRow: 10, numberOfLines: 6556},
    90  	}
    91  
    92  	for _, test := range testcases {
    93  		t.Run(fmt.Sprintf("%v_%v", test.numberOfBytesInEachRow, test.numberOfLines), func(t *testing.T) {
    94  			tmpDir, err := generateKLinesOfNByteRows(test.numberOfLines, test.numberOfBytesInEachRow, t.TempDir())
    95  			if err != nil {
    96  				t.Error(err)
    97  			}
    98  
    99  			encryptDecryptFile(t, encMat, test.numberOfLines, tmpDir)
   100  		})
   101  	}
   102  }
   103  
   104  func TestEncryptDecryptLargeFile(t *testing.T) {
   105  	encMat := snowflakeFileEncryption{
   106  		"ztke8tIdVt1zmlQIZm0BMA==",
   107  		"123873c7-3a66-40c4-ab89-e3722fbccce1",
   108  		3112,
   109  	}
   110  
   111  	numberOfFiles := 1
   112  	numberOfLines := 10000
   113  	tmpDir, err := generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, t.TempDir())
   114  	if err != nil {
   115  		t.Error(err)
   116  	}
   117  
   118  	encryptDecryptFile(t, encMat, numberOfLines, tmpDir)
   119  }
   120  
   121  func encryptDecryptFile(t *testing.T, encMat snowflakeFileEncryption, expected int, tmpDir string) {
   122  	files, err := filepath.Glob(filepath.Join(tmpDir, "file*"))
   123  	if err != nil {
   124  		t.Error(err)
   125  	}
   126  	inputFile := files[0]
   127  
   128  	metadata, encryptedFile, err := encryptFile(&encMat, inputFile, 0, tmpDir)
   129  	if err != nil {
   130  		t.Error(err)
   131  	}
   132  	defer os.Remove(encryptedFile)
   133  	decryptedFile, err := decryptFile(metadata, &encMat, encryptedFile, 0, tmpDir)
   134  	if err != nil {
   135  		t.Error(err)
   136  	}
   137  	defer os.Remove(decryptedFile)
   138  
   139  	cnt := 0
   140  	fd, err := os.Open(decryptedFile)
   141  	if err != nil {
   142  		t.Error(err)
   143  	}
   144  	defer fd.Close()
   145  
   146  	scanner := bufio.NewScanner(fd)
   147  	for scanner.Scan() {
   148  		cnt++
   149  	}
   150  	if err = scanner.Err(); err != nil {
   151  		t.Error(err)
   152  	}
   153  	if cnt != expected {
   154  		t.Fatalf("incorrect number of lines. expected: %v, got: %v", expected, cnt)
   155  	}
   156  }
   157  
   158  func generateKLinesOfNByteRows(numLines int, numBytes int, tmpDir string) (string, error) {
   159  	fname := path.Join(tmpDir, "file"+strconv.FormatInt(int64(numLines*numBytes), 10))
   160  	f, err := os.Create(fname)
   161  	if err != nil {
   162  		return "", err
   163  	}
   164  
   165  	for j := 0; j < numLines; j++ {
   166  		str := randomString(numBytes - 1) // \n is the last character
   167  		rec := fmt.Sprintf("%v\n", str)
   168  		f.Write([]byte(rec))
   169  	}
   170  	f.Close()
   171  	return tmpDir, nil
   172  }
   173  
   174  func generateKLinesOfNFiles(k int, n int, compress bool, tmpDir string) (string, error) {
   175  	for i := 0; i < n; i++ {
   176  		fname := path.Join(tmpDir, "file"+strconv.FormatInt(int64(i), 10))
   177  		f, err := os.Create(fname)
   178  		if err != nil {
   179  			return "", err
   180  		}
   181  		for j := 0; j < k; j++ {
   182  			num := rand.Float64() * 10000
   183  			min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix()
   184  			max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix()
   185  			delta := max - min
   186  			sec := rand.Int63n(delta) + min
   187  			tm := time.Unix(sec, 0)
   188  			dt := tm.Format("2021-03-01")
   189  			sec = rand.Int63n(delta) + min
   190  			ts := time.Unix(sec, 0).Format(timeFormat)
   191  			sec = rand.Int63n(delta) + min
   192  			tsltz := time.Unix(sec, 0).Format(timeFormat)
   193  			sec = rand.Int63n(delta) + min
   194  			tsntz := time.Unix(sec, 0).Format(timeFormat)
   195  			sec = rand.Int63n(delta) + min
   196  			tstz := time.Unix(sec, 0).Format(timeFormat)
   197  			pct := rand.Float64() * 1000
   198  			ratio := fmt.Sprintf("%.2f", rand.Float64()*1000)
   199  			rec := fmt.Sprintf("%v,%v,%v,%v,%v,%v,%v,%v\n", num, dt, ts, tsltz, tsntz, tstz, pct, ratio)
   200  			f.Write([]byte(rec))
   201  		}
   202  		f.Close()
   203  		if compress {
   204  			if !isWindows {
   205  				gzipCmd := exec.Command("gzip", filepath.Join(tmpDir, "file"+strconv.FormatInt(int64(i), 10)))
   206  				gzipOut, err := gzipCmd.StdoutPipe()
   207  				if err != nil {
   208  					return "", err
   209  				}
   210  				gzipErr, err := gzipCmd.StderrPipe()
   211  				if err != nil {
   212  					return "", err
   213  				}
   214  				gzipCmd.Start()
   215  				io.ReadAll(gzipOut)
   216  				io.ReadAll(gzipErr)
   217  				gzipCmd.Wait()
   218  			} else {
   219  				fOut, err := os.Create(fname + ".gz")
   220  				if err != nil {
   221  					return "", err
   222  				}
   223  				w := gzip.NewWriter(fOut)
   224  				fIn, err := os.Open(fname)
   225  				if err != nil {
   226  					return "", err
   227  				}
   228  				if _, err = io.Copy(w, fIn); err != nil {
   229  					return "", err
   230  				}
   231  				w.Close()
   232  				fOut.Close()
   233  				fIn.Close()
   234  			}
   235  		}
   236  	}
   237  	return tmpDir, nil
   238  }