github.com/snowflakedb/gosnowflake@v1.9.0/file_util.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"compress/gzip"
     8  	"crypto/sha256"
     9  	"encoding/base64"
    10  	"io"
    11  	"net/url"
    12  	"os"
    13  	usr "os/user"
    14  	"path/filepath"
    15  	"strings"
    16  )
    17  
    18  type snowflakeFileUtil struct {
    19  }
    20  
    21  const (
    22  	fileChunkSize                 = 16 * 4 * 1024
    23  	readWriteFileMode os.FileMode = 0666
    24  )
    25  
    26  func (util *snowflakeFileUtil) compressFileWithGzipFromStream(srcStream **bytes.Buffer) (*bytes.Buffer, int, error) {
    27  	r := getReaderFromBuffer(srcStream)
    28  	buf, err := io.ReadAll(r)
    29  	if err != nil {
    30  		return nil, -1, err
    31  	}
    32  	var c bytes.Buffer
    33  	w := gzip.NewWriter(&c)
    34  	w.Write(buf) // write buf to gzip writer
    35  	w.Close()
    36  	return &c, c.Len(), nil
    37  }
    38  
    39  func (util *snowflakeFileUtil) compressFileWithGzip(fileName string, tmpDir string) (string, int64, error) {
    40  	basename := baseName(fileName)
    41  	gzipFileName := filepath.Join(tmpDir, basename+"_c.gz")
    42  
    43  	fr, err := os.Open(fileName)
    44  	if err != nil {
    45  		return "", -1, err
    46  	}
    47  	defer fr.Close()
    48  	fw, err := os.OpenFile(gzipFileName, os.O_WRONLY|os.O_CREATE, readWriteFileMode)
    49  	if err != nil {
    50  		return "", -1, err
    51  	}
    52  	gzw := gzip.NewWriter(fw)
    53  	defer gzw.Close()
    54  	io.Copy(gzw, fr)
    55  
    56  	stat, err := os.Stat(gzipFileName)
    57  	if err != nil {
    58  		return "", -1, err
    59  	}
    60  	return gzipFileName, stat.Size(), nil
    61  }
    62  
    63  func (util *snowflakeFileUtil) getDigestAndSizeForStream(stream **bytes.Buffer) (string, int64, error) {
    64  	m := sha256.New()
    65  	r := getReaderFromBuffer(stream)
    66  	chunk := make([]byte, fileChunkSize)
    67  
    68  	for {
    69  		n, err := r.Read(chunk)
    70  		if err == io.EOF {
    71  			break
    72  		} else if err != nil {
    73  			return "", 0, err
    74  		}
    75  		m.Write(chunk[:n])
    76  	}
    77  	return base64.StdEncoding.EncodeToString(m.Sum(nil)), int64((*stream).Len()), nil
    78  }
    79  
    80  func (util *snowflakeFileUtil) getDigestAndSizeForFile(fileName string) (string, int64, error) {
    81  	f, err := os.Open(fileName)
    82  	if err != nil {
    83  		return "", 0, err
    84  	}
    85  	defer f.Close()
    86  
    87  	var total int64
    88  	m := sha256.New()
    89  	chunk := make([]byte, fileChunkSize)
    90  
    91  	for {
    92  		n, err := f.Read(chunk)
    93  		if err == io.EOF {
    94  			break
    95  		} else if err != nil {
    96  			return "", 0, err
    97  		}
    98  		total += int64(n)
    99  		m.Write(chunk[:n])
   100  	}
   101  	f.Seek(0, io.SeekStart)
   102  	return base64.StdEncoding.EncodeToString(m.Sum(nil)), total, nil
   103  }
   104  
// fileMetadata is the per-file state for a PUT/GET transfer.
// NOTE(review): the group descriptions below are inferred from field names
// and the section markers already present in this struct; confirm against
// snowflakeFileTransferAgent before relying on them.
type fileMetadata struct {
	// File name, owning transfer agent, stage location/info, result status,
	// and encryption material.
	name               string
	sfa                *snowflakeFileTransferAgent
	stageLocationType  cloudType
	resStatus          resultStatus
	stageInfo          *execResponseStageInfo
	encryptionMaterial *snowflakeFileEncryption

	// Source/destination names, sizes, and compression types.
	// NOTE(review): the "real" variants presumably keep the caller's original
	// source when the working source is replaced (e.g. by a compressed copy) —
	// confirm.
	srcFileName        string
	realSrcFileName    string
	srcFileSize        int64
	srcCompressionType *compressionType
	uploadSize         int64
	dstFileSize        int64
	dstFileName        string
	dstCompressionType *compressionType

	// Cloud client plus per-transfer options and bookkeeping (digest,
	// overwrite flag, temp dir, errors, concurrency, local location).
	client             cloudClient // *s3.Client (S3), *azblob.ContainerURL (Azure), string (GCS)
	requireCompress    bool
	parallel           int64
	sha256Digest       string
	overwrite          bool
	tmpDir             string
	errorDetails       error
	lastError          error
	noSleepingTime     bool
	lastMaxConcurrency int
	localLocation      string
	options            *SnowflakeFileTransferOptions

	/* streaming PUT */
	// In-memory source buffers used by streaming PUT (presumably in place of
	// srcFileName/realSrcFileName — verify).
	srcStream     *bytes.Buffer
	realSrcStream *bytes.Buffer

	/* GCS */
	// GCS-specific: presigned URL and cached file-header values.
	presignedURL                *url.URL
	gcsFileHeaderDigest         string
	gcsFileHeaderContentLength  int64
	gcsFileHeaderEncryptionMeta *encryptMetadata

	/* mock */
	// Mock implementations of the cloud APIs; presumably injected by tests to
	// replace the real clients — confirm at the use sites.
	mockUploader    s3UploadAPI
	mockDownloader  s3DownloadAPI
	mockHeader      s3HeaderAPI
	mockGcsClient   gcsAPI
	mockAzureClient azureAPI
}
   153  
// fileTransferResultType summarizes the outcome for one file of a transfer:
// names, sizes, compression types, final status, and any error.
// NOTE(review): its fields mirror a subset of fileMetadata, so it is presumably
// populated from fileMetadata after the transfer — confirm at call sites.
type fileTransferResultType struct {
	name               string
	srcFileName        string
	dstFileName        string
	srcFileSize        int64
	dstFileSize        int64
	srcCompressionType *compressionType
	dstCompressionType *compressionType
	resStatus          resultStatus
	errorDetails       error
}
   165  
// fileHeader bundles a file's digest, content length, and encryption metadata.
// NOTE(review): presumably built from a stored object's header/metadata (it
// parallels the gcsFileHeader* fields of fileMetadata) — confirm with callers.
type fileHeader struct {
	digest             string
	contentLength      int64
	encryptionMetadata *encryptMetadata
}
   171  
   172  func getReaderFromBuffer(src **bytes.Buffer) io.Reader {
   173  	var b bytes.Buffer
   174  	tee := io.TeeReader(*src, &b) // read src to buf
   175  	*src = &b                     // revert pointer back
   176  	return tee
   177  }
   178  
   179  // baseName returns the pathname of the path provided
   180  func baseName(path string) string {
   181  	base := filepath.Base(path)
   182  	if base == "." || base == "/" {
   183  		return ""
   184  	}
   185  	if len(base) > 1 && (path[len(path)-1:] == "." || path[len(path)-1:] == "/") {
   186  		return ""
   187  	}
   188  	return base
   189  }
   190  
   191  // expandUser returns the argument with an initial component of ~
   192  func expandUser(path string) (string, error) {
   193  	usr, err := usr.Current()
   194  	if err != nil {
   195  		return "", err
   196  	}
   197  	dir := usr.HomeDir
   198  	if path == "~" {
   199  		path = dir
   200  	} else if strings.HasPrefix(path, "~/") {
   201  		path = filepath.Join(dir, path[2:])
   202  	}
   203  	return path, nil
   204  }
   205  
   206  // getDirectory retrieves the current working directory
   207  func getDirectory() (string, error) {
   208  	ex, err := os.Executable()
   209  	if err != nil {
   210  		return "", err
   211  	}
   212  	return filepath.Dir(ex), nil
   213  }