github.com/snowflakedb/gosnowflake@v1.9.0/encrypt_util_test.go (about) 1 // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bufio" 7 "compress/gzip" 8 "fmt" 9 "io" 10 "math/rand" 11 "os" 12 "os/exec" 13 "path" 14 "path/filepath" 15 "strconv" 16 "testing" 17 "time" 18 ) 19 20 const timeFormat = "2006-01-02T15:04:05" 21 22 type encryptDecryptTestFile struct { 23 numberOfBytesInEachRow int 24 numberOfLines int 25 } 26 27 func TestEncryptDecryptFile(t *testing.T) { 28 encMat := snowflakeFileEncryption{ 29 "ztke8tIdVt1zmlQIZm0BMA==", 30 "123873c7-3a66-40c4-ab89-e3722fbccce1", 31 3112, 32 } 33 data := "test data" 34 inputFile := "test_encrypt_decrypt_file" 35 36 fd, err := os.Create(inputFile) 37 if err != nil { 38 t.Error(err) 39 } 40 defer fd.Close() 41 defer os.Remove(inputFile) 42 if _, err = fd.Write([]byte(data)); err != nil { 43 t.Error(err) 44 } 45 46 metadata, encryptedFile, err := encryptFile(&encMat, inputFile, 0, "") 47 if err != nil { 48 t.Error(err) 49 } 50 defer os.Remove(encryptedFile) 51 decryptedFile, err := decryptFile(metadata, &encMat, encryptedFile, 0, "") 52 if err != nil { 53 t.Error(err) 54 } 55 defer os.Remove(decryptedFile) 56 57 fd, err = os.Open(decryptedFile) 58 if err != nil { 59 t.Error(err) 60 } 61 defer fd.Close() 62 content, err := io.ReadAll(fd) 63 if err != nil { 64 t.Error(err) 65 } 66 if string(content) != data { 67 t.Fatalf("data did not match content. expected: %v, got: %v", data, string(content)) 68 } 69 } 70 71 func TestEncryptDecryptFilePadding(t *testing.T) { 72 encMat := snowflakeFileEncryption{ 73 "ztke8tIdVt1zmlQIZm0BMA==", 74 "123873c7-3a66-40c4-ab89-e3722fbccce1", 75 3112, 76 } 77 78 testcases := []encryptDecryptTestFile{ 79 // File size is a multiple of 65536 bytes (chunkSize) 80 {numberOfBytesInEachRow: 8, numberOfLines: 16384}, 81 {numberOfBytesInEachRow: 16, numberOfLines: 4096}, 82 // File size is not a multiple of 65536 bytes (chunkSize) 83 {numberOfBytesInEachRow: 8, numberOfLines: 10240}, 84 {numberOfBytesInEachRow: 16, numberOfLines: 6144}, 85 // The second chunk's size is a multiple of 16 bytes (aes.BlockSize) 86 {numberOfBytesInEachRow: 16, numberOfLines: 4097}, 87 // The second chunk's size is not a multiple of 16 bytes (aes.BlockSize) 88 {numberOfBytesInEachRow: 12, numberOfLines: 5462}, 89 {numberOfBytesInEachRow: 10, numberOfLines: 6556}, 90 } 91 92 for _, test := range testcases { 93 t.Run(fmt.Sprintf("%v_%v", test.numberOfBytesInEachRow, test.numberOfLines), func(t *testing.T) { 94 tmpDir, err := generateKLinesOfNByteRows(test.numberOfLines, test.numberOfBytesInEachRow, t.TempDir()) 95 if err != nil { 96 t.Error(err) 97 } 98 99 encryptDecryptFile(t, encMat, test.numberOfLines, tmpDir) 100 }) 101 } 102 } 103 104 func TestEncryptDecryptLargeFile(t *testing.T) { 105 encMat := snowflakeFileEncryption{ 106 "ztke8tIdVt1zmlQIZm0BMA==", 107 "123873c7-3a66-40c4-ab89-e3722fbccce1", 108 3112, 109 } 110 111 numberOfFiles := 1 112 numberOfLines := 10000 113 tmpDir, err := generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, t.TempDir()) 114 if err != nil { 115 t.Error(err) 116 } 117 118 encryptDecryptFile(t, encMat, numberOfLines, tmpDir) 119 } 120 121 func encryptDecryptFile(t *testing.T, encMat snowflakeFileEncryption, expected int, tmpDir string) { 122 files, err := filepath.Glob(filepath.Join(tmpDir, "file*")) 123 if err != nil { 124 t.Error(err) 125 } 126 inputFile := files[0] 127 128 metadata, encryptedFile, err := encryptFile(&encMat, inputFile, 0, tmpDir) 129 if err != nil { 130 t.Error(err) 131 } 132 defer os.Remove(encryptedFile) 133 decryptedFile, err := decryptFile(metadata, &encMat, encryptedFile, 0, tmpDir) 134 if err != nil { 135 t.Error(err) 136 } 137 defer os.Remove(decryptedFile) 138 139 cnt := 0 140 fd, err := os.Open(decryptedFile) 141 if err != nil { 142 t.Error(err) 143 } 144 defer fd.Close() 145 146 scanner := bufio.NewScanner(fd) 147 for scanner.Scan() { 148 cnt++ 149 } 150 if err = scanner.Err(); err != nil { 151 t.Error(err) 152 } 153 if cnt != expected { 154 t.Fatalf("incorrect number of lines. expected: %v, got: %v", expected, cnt) 155 } 156 } 157 158 func generateKLinesOfNByteRows(numLines int, numBytes int, tmpDir string) (string, error) { 159 fname := path.Join(tmpDir, "file"+strconv.FormatInt(int64(numLines*numBytes), 10)) 160 f, err := os.Create(fname) 161 if err != nil { 162 return "", err 163 } 164 165 for j := 0; j < numLines; j++ { 166 str := randomString(numBytes - 1) // \n is the last character 167 rec := fmt.Sprintf("%v\n", str) 168 f.Write([]byte(rec)) 169 } 170 f.Close() 171 return tmpDir, nil 172 } 173 174 func generateKLinesOfNFiles(k int, n int, compress bool, tmpDir string) (string, error) { 175 for i := 0; i < n; i++ { 176 fname := path.Join(tmpDir, "file"+strconv.FormatInt(int64(i), 10)) 177 f, err := os.Create(fname) 178 if err != nil { 179 return "", err 180 } 181 for j := 0; j < k; j++ { 182 num := rand.Float64() * 10000 183 min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix() 184 max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix() 185 delta := max - min 186 sec := rand.Int63n(delta) + min 187 tm := time.Unix(sec, 0) 188 dt := tm.Format("2021-03-01") 189 sec = rand.Int63n(delta) + min 190 ts := time.Unix(sec, 0).Format(timeFormat) 191 sec = rand.Int63n(delta) + min 192 tsltz := time.Unix(sec, 0).Format(timeFormat) 193 sec = rand.Int63n(delta) + min 194 tsntz := time.Unix(sec, 0).Format(timeFormat) 195 sec = rand.Int63n(delta) + min 196 tstz := time.Unix(sec, 0).Format(timeFormat) 197 pct := rand.Float64() * 1000 198 ratio := fmt.Sprintf("%.2f", rand.Float64()*1000) 199 rec := fmt.Sprintf("%v,%v,%v,%v,%v,%v,%v,%v\n", num, dt, ts, tsltz, tsntz, tstz, pct, ratio) 200 f.Write([]byte(rec)) 201 } 202 f.Close() 203 if compress { 204 if !isWindows { 205 gzipCmd := exec.Command("gzip", filepath.Join(tmpDir, "file"+strconv.FormatInt(int64(i), 10))) 206 gzipOut, err := gzipCmd.StdoutPipe() 207 if err != nil { 208 return "", err 209 } 210 gzipErr, err := gzipCmd.StderrPipe() 211 if err != nil { 212 return "", err 213 } 214 gzipCmd.Start() 215 io.ReadAll(gzipOut) 216 io.ReadAll(gzipErr) 217 gzipCmd.Wait() 218 } else { 219 fOut, err := os.Create(fname + ".gz") 220 if err != nil { 221 return "", err 222 } 223 w := gzip.NewWriter(fOut) 224 fIn, err := os.Open(fname) 225 if err != nil { 226 return "", err 227 } 228 if _, err = io.Copy(w, fIn); err != nil { 229 return "", err 230 } 231 w.Close() 232 fOut.Close() 233 fIn.Close() 234 } 235 } 236 } 237 return tmpDir, nil 238 }