github.com/snowflakedb/gosnowflake@v1.9.0/file_util.go (about) 1 // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "compress/gzip" 8 "crypto/sha256" 9 "encoding/base64" 10 "io" 11 "net/url" 12 "os" 13 usr "os/user" 14 "path/filepath" 15 "strings" 16 ) 17 18 type snowflakeFileUtil struct { 19 } 20 21 const ( 22 fileChunkSize = 16 * 4 * 1024 23 readWriteFileMode os.FileMode = 0666 24 ) 25 26 func (util *snowflakeFileUtil) compressFileWithGzipFromStream(srcStream **bytes.Buffer) (*bytes.Buffer, int, error) { 27 r := getReaderFromBuffer(srcStream) 28 buf, err := io.ReadAll(r) 29 if err != nil { 30 return nil, -1, err 31 } 32 var c bytes.Buffer 33 w := gzip.NewWriter(&c) 34 w.Write(buf) // write buf to gzip writer 35 w.Close() 36 return &c, c.Len(), nil 37 } 38 39 func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir string) (string, int64, error) { 40 basename := baseName(fileName) 41 gzipFileName := filepath.Join(tmpDir, basename+"_c.gz") 42 43 fr, err := os.Open(fileName) 44 if err != nil { 45 return "", -1, err 46 } 47 defer fr.Close() 48 fw, err := os.OpenFile(gzipFileName, os.O_WRONLY|os.O_CREATE, readWriteFileMode) 49 if err != nil { 50 return "", -1, err 51 } 52 gzw := gzip.NewWriter(fw) 53 defer gzw.Close() 54 io.Copy(gzw, fr) 55 56 stat, err := os.Stat(gzipFileName) 57 if err != nil { 58 return "", -1, err 59 } 60 return gzipFileName, stat.Size(), nil 61 } 62 63 func (util *snowflakeFileUtil) getDigestAndSizeForStream(stream **bytes.Buffer) (string, int64, error) { 64 m := sha256.New() 65 r := getReaderFromBuffer(stream) 66 chunk := make([]byte, fileChunkSize) 67 68 for { 69 n, err := r.Read(chunk) 70 if err == io.EOF { 71 break 72 } else if err != nil { 73 return "", 0, err 74 } 75 m.Write(chunk[:n]) 76 } 77 return base64.StdEncoding.EncodeToString(m.Sum(nil)), int64((*stream).Len()), nil 78 } 79 80 func (util *snowflakeFileUtil) getDigestAndSizeForFile(fileName string) (string, int64, error) { 81 f, err := os.Open(fileName) 82 if err != nil { 83 return "", 0, err 84 } 85 defer f.Close() 86 87 var total int64 88 m := sha256.New() 89 chunk := make([]byte, fileChunkSize) 90 91 for { 92 n, err := f.Read(chunk) 93 if err == io.EOF { 94 break 95 } else if err != nil { 96 return "", 0, err 97 } 98 total += int64(n) 99 m.Write(chunk[:n]) 100 } 101 f.Seek(0, io.SeekStart) 102 return base64.StdEncoding.EncodeToString(m.Sum(nil)), total, nil 103 } 104 105 // file metadata for PUT/GET 106 type fileMetadata struct { 107 name string 108 sfa *snowflakeFileTransferAgent 109 stageLocationType cloudType 110 resStatus resultStatus 111 stageInfo *execResponseStageInfo 112 encryptionMaterial *snowflakeFileEncryption 113 114 srcFileName string 115 realSrcFileName string 116 srcFileSize int64 117 srcCompressionType *compressionType 118 uploadSize int64 119 dstFileSize int64 120 dstFileName string 121 dstCompressionType *compressionType 122 123 client cloudClient // *s3.Client (S3), *azblob.ContainerURL (Azure), string (GCS) 124 requireCompress bool 125 parallel int64 126 sha256Digest string 127 overwrite bool 128 tmpDir string 129 errorDetails error 130 lastError error 131 noSleepingTime bool 132 lastMaxConcurrency int 133 localLocation string 134 options *SnowflakeFileTransferOptions 135 136 /* streaming PUT */ 137 srcStream *bytes.Buffer 138 realSrcStream *bytes.Buffer 139 140 /* GCS */ 141 presignedURL *url.URL 142 gcsFileHeaderDigest string 143 gcsFileHeaderContentLength int64 144 gcsFileHeaderEncryptionMeta *encryptMetadata 145 146 /* mock */ 147 mockUploader s3UploadAPI 148 mockDownloader s3DownloadAPI 149 mockHeader s3HeaderAPI 150 mockGcsClient gcsAPI 151 mockAzureClient azureAPI 152 } 153 154 type fileTransferResultType struct { 155 name string 156 srcFileName string 157 dstFileName string 158 srcFileSize int64 159 dstFileSize int64 160 srcCompressionType *compressionType 161 dstCompressionType *compressionType 162 resStatus resultStatus 163 errorDetails error 164 } 165 166 type fileHeader struct { 167 digest string 168 contentLength int64 169 encryptionMetadata *encryptMetadata 170 } 171 172 func getReaderFromBuffer(src **bytes.Buffer) io.Reader { 173 var b bytes.Buffer 174 tee := io.TeeReader(*src, &b) // read src to buf 175 *src = &b // revert pointer back 176 return tee 177 } 178 179 // baseName returns the pathname of the path provided 180 func baseName(path string) string { 181 base := filepath.Base(path) 182 if base == "." || base == "/" { 183 return "" 184 } 185 if len(base) > 1 && (path[len(path)-1:] == "." || path[len(path)-1:] == "/") { 186 return "" 187 } 188 return base 189 } 190 191 // expandUser returns the argument with an initial component of ~ 192 func expandUser(path string) (string, error) { 193 usr, err := usr.Current() 194 if err != nil { 195 return "", err 196 } 197 dir := usr.HomeDir 198 if path == "~" { 199 path = dir 200 } else if strings.HasPrefix(path, "~/") { 201 path = filepath.Join(dir, path[2:]) 202 } 203 return path, nil 204 } 205 206 // getDirectory retrieves the current working directory 207 func getDirectory() (string, error) { 208 ex, err := os.Executable() 209 if err != nil { 210 return "", err 211 } 212 return filepath.Dir(ex), nil 213 }