github.com/snowflakedb/gosnowflake@v1.9.0/put_get_test.go

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"compress/gzip"
     8  	"context"
     9  	"fmt"
    10  	"io"
    11  	"math/rand"
    12  	"os"
    13  	"os/user"
    14  	"path/filepath"
    15  	"strconv"
    16  	"strings"
    17  	"testing"
    18  	"time"
    19  )
    20  
    21  const createStageStmt = "CREATE OR REPLACE STAGE %v URL = '%v' CREDENTIALS = (%v)"
    22  
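         // randomString returns a pseudo-random lowercase string of length n. It reseeds the
         // global PRNG on every call and is used to build unique object names for the tests.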
    23  func randomString(n int) string {
    24  	rand.Seed(time.Now().UnixNano())
    25  	alpha := []rune("abcdefghijklmnopqrstuvwxyz")
    26  	b := make([]rune, n)
    27  	for i := range b {
    28  		b[i] = alpha[rand.Intn(len(alpha))]
    29  	}
    30  	return string(b)
    31  }
    32  
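         // TestPutError verifies that a PUT of an unreadable local file only surfaces an error
         // when RaisePutGetError is enabled on the file transfer agent.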
    33  func TestPutError(t *testing.T) {
    34  	if isWindows {
    35  		t.Skip("permission model is different")
    36  	}
    37  	tmpDir := t.TempDir()
    38  	file1 := filepath.Join(tmpDir, "file1")
    39  	remoteLocation := filepath.Join(tmpDir, "remote_loc")
    40  	f, err := os.Create(file1)
    41  	if err != nil {
     42  		t.Fatal(err)
    43  	}
    44  	defer f.Close()
    45  	f.WriteString("test1")
    46  	os.Chmod(file1, 0000)
    47  	defer os.Chmod(file1, 0644)
    48  
    49  	data := &execResponseData{
    50  		Command:           string(uploadCommand),
    51  		AutoCompress:      false,
    52  		SrcLocations:      []string{file1},
    53  		SourceCompression: "none",
    54  		StageInfo: execResponseStageInfo{
    55  			Location:     remoteLocation,
    56  			LocationType: string(local),
    57  			Path:         "remote_loc",
    58  		},
    59  	}
    60  
    61  	fta := &snowflakeFileTransferAgent{
    62  		data: data,
    63  		options: &SnowflakeFileTransferOptions{
    64  			RaisePutGetError: false,
    65  		},
    66  		sc: &snowflakeConn{
    67  			cfg: &Config{},
    68  		},
    69  	}
    70  	if err = fta.execute(); err != nil {
    71  		t.Fatal(err)
    72  	}
    73  	if _, err = fta.result(); err != nil {
    74  		t.Fatal(err)
    75  	}
    76  
    77  	fta = &snowflakeFileTransferAgent{
    78  		data: data,
    79  		options: &SnowflakeFileTransferOptions{
    80  			RaisePutGetError: true,
    81  		},
    82  		sc: &snowflakeConn{
    83  			cfg: &Config{},
    84  		},
    85  	}
    86  	if err = fta.execute(); err != nil {
    87  		t.Fatal(err)
    88  	}
    89  	if _, err = fta.result(); err == nil {
    90  		t.Fatalf("should raise permission error")
    91  	}
    92  }
    93  
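         // TestPercentage exercises snowflakeProgressPercentage.percent, including the cases where
         // the total size is zero or smaller than the number of bytes already seen.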
    94  func TestPercentage(t *testing.T) {
    95  	testcases := []struct {
    96  		seen     int64
    97  		size     float64
    98  		expected float64
    99  	}{
   100  		{0, 0, 1.0},
   101  		{20, 0, 1.0},
   102  		{40, 20, 1.0},
   103  		{14, 28, 0.5},
   104  	}
   105  	for _, test := range testcases {
   106  		t.Run(fmt.Sprintf("%v_%v_%v", test.seen, test.size, test.expected), func(t *testing.T) {
   107  			spp := snowflakeProgressPercentage{}
   108  			if spp.percent(test.seen, test.size) != test.expected {
   109  				t.Fatalf("percentage conversion failed. %v/%v, expected: %v, got: %v",
   110  					test.seen, test.size, test.expected, spp.percent(test.seen, test.size))
   111  			}
   112  		})
   113  	}
   114  }
   115  
   116  type tcPutGetData struct {
   117  	dir                string
   118  	awsAccessKeyID     string
   119  	awsSecretAccessKey string
   120  	stage              string
   121  	warehouse          string
   122  	database           string
   123  	userBucket         string
   124  }
   125  
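         // cleanupPut drops the database and warehouse created by createTestData.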
   126  func cleanupPut(dbt *DBTest, td *tcPutGetData) {
   127  	dbt.mustExec("drop database " + td.database)
   128  	dbt.mustExec("drop warehouse " + td.warehouse)
   129  }
   130  
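         // getAWSCredentials reads AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment
         // and falls back to a per-user sfc-eng-regression bucket when SF_AWS_USER_BUCKET is unset.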
   131  func getAWSCredentials() (string, string, string, error) {
   132  	keyID, ok := os.LookupEnv("AWS_ACCESS_KEY_ID")
   133  	if !ok {
   134  		return "", "", "", fmt.Errorf("key id invalid")
   135  	}
   136  	secretKey, ok := os.LookupEnv("AWS_SECRET_ACCESS_KEY")
   137  	if !ok {
   138  		return keyID, "", "", fmt.Errorf("secret key invalid")
   139  	}
   140  	bucket, present := os.LookupEnv("SF_AWS_USER_BUCKET")
   141  	if !present {
   142  		user, err := user.Current()
   143  		if err != nil {
   144  			return keyID, secretKey, "", err
   145  		}
   146  		bucket = fmt.Sprintf("sfc-eng-regression/%v/reg", user.Username)
   147  	}
   148  	return keyID, secretKey, bucket, nil
   149  }
   150  
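         // createTestData provisions a uniquely named warehouse, database, schema, and CSV file
         // format for the PUT/GET tests; it requires access to the sysadmin role.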
   151  func createTestData(dbt *DBTest) (*tcPutGetData, error) {
   152  	keyID, secretKey, bucket, err := getAWSCredentials()
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	uniqueName := randomString(10)
   157  	database := fmt.Sprintf("%v_db", uniqueName)
   158  	wh := fmt.Sprintf("%v_wh", uniqueName)
   159  
   160  	dir, err := os.Getwd()
   161  	if err != nil {
   162  		return nil, err
   163  	}
   164  	ret := tcPutGetData{
   165  		dir,
   166  		keyID,
   167  		secretKey,
   168  		fmt.Sprintf("%v_stage", uniqueName),
   169  		wh,
   170  		database,
   171  		bucket,
   172  	}
   173  
   174  	if _, err = dbt.exec("use role sysadmin"); err != nil {
   175  		return nil, err
   176  	}
   177  	dbt.mustExec(fmt.Sprintf(
   178  		"create or replace warehouse %v warehouse_size='small' "+
   179  			"warehouse_type='standard' auto_suspend=1800", wh))
   180  	dbt.mustExec("create or replace database " + database)
   181  	dbt.mustExec("create or replace schema gotesting_schema")
   182  	dbt.mustExec("create or replace file format VSV type = 'CSV' " +
   183  		"field_delimiter='|' error_on_column_count_mismatch=false")
   184  	return &ret, nil
   185  }
   186  
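         // TestPutLocalFile PUTs local CSV test files into a table stage backed by an external S3
         // location, copies them into the table, and checks the load results.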
   187  func TestPutLocalFile(t *testing.T) {
   188  	if runningOnGithubAction() && !runningOnAWS() {
   189  		t.Skip("skipping non aws environment")
   190  	}
   191  	runDBTest(t, func(dbt *DBTest) {
   192  		data, err := createTestData(dbt)
   193  		if err != nil {
   194  			t.Skip("snowflake admin account not accessible")
   195  		}
   196  		defer cleanupPut(dbt, data)
   197  		dbt.mustExec("use warehouse " + data.warehouse)
   198  		dbt.mustExec("alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false")
   199  		dbt.mustExec("use schema " + data.database + ".gotesting_schema")
   200  		execQuery := fmt.Sprintf(
   201  			`create or replace table gotest_putget_t1 (c1 STRING, c2 STRING,
   202  			c3 STRING, c4 STRING, c5 STRING, c6 STRING, c7 STRING, c8 STRING,
   203  			c9 STRING) stage_file_format = ( field_delimiter = '|'
   204  			error_on_column_count_mismatch=false) stage_copy_options =
   205  			(purge=false) stage_location = (url = 's3://%v/%v' credentials =
   206  			(AWS_KEY_ID='%v' AWS_SECRET_KEY='%v'))`,
   207  			data.userBucket,
   208  			data.stage,
   209  			data.awsAccessKeyID,
   210  			data.awsSecretAccessKey)
   211  		dbt.mustExec(execQuery)
   212  		defer dbt.mustExec("drop table if exists gotest_putget_t1")
   213  
   214  		execQuery = fmt.Sprintf(`put file://%v/test_data/orders_10*.csv
   215  			@%%gotest_putget_t1`, data.dir)
   216  		dbt.mustExec(execQuery)
   217  		dbt.mustQueryAssertCount("ls @%gotest_putget_t1", 2)
   218  
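         		// copy the staged files into the table and confirm that every file reports LOADED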
   219  		var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 string
   220  		rows := dbt.mustQuery("copy into gotest_putget_t1")
   221  		defer rows.Close()
   222  		for rows.Next() {
   223  			rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9)
   224  			if s1 != "LOADED" {
   225  				t.Fatal("not loaded")
   226  			}
   227  		}
   228  
   229  		rows2 := dbt.mustQuery("select count(*) from gotest_putget_t1")
   230  		defer rows2.Close()
   231  		var i int
   232  		if rows2.Next() {
   233  			rows2.Scan(&i)
   234  			if i != 75 {
   235  				t.Fatalf("expected 75 rows, got %v", i)
   236  			}
   237  		}
   238  
    239  		rows3 := dbt.mustQuery(`select STATUS from information_schema.load_history where table_name='gotest_putget_t1'`)
    240  		defer rows3.Close()
    241  		if rows3.Next() {
    242  			rows3.Scan(&s0)
    243  			if s0 != "LOADED" {
    244  				t.Fatal("not loaded")
    245  			}
   246  		}
   247  	})
   248  }
   249  
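         // TestPutWithAutoCompressFalse uploads a file with auto_compress=FALSE and verifies the
         // staged copy keeps its original name instead of gaining a .gz suffix.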
   250  func TestPutWithAutoCompressFalse(t *testing.T) {
   251  	if runningOnGithubAction() && !runningOnAWS() {
   252  		t.Skip("skipping non aws environment")
   253  	}
   254  	tmpDir := t.TempDir()
   255  	testData := filepath.Join(tmpDir, "data.txt")
   256  	f, err := os.Create(testData)
   257  	if err != nil {
    258  		t.Fatal(err)
   259  	}
   260  	f.WriteString("test1,test2\ntest3,test4")
   261  	f.Sync()
   262  	defer f.Close()
   263  
   264  	runDBTest(t, func(dbt *DBTest) {
   265  		if _, err = dbt.exec("use role sysadmin"); err != nil {
   266  			t.Skip("snowflake admin account not accessible")
   267  		}
   268  		dbt.mustExec("rm @~/test_put_uncompress_file")
   269  		sqlText := fmt.Sprintf("put file://%v @~/test_put_uncompress_file auto_compress=FALSE", testData)
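         		// double any backslashes so that Windows paths survive inside the SQL string literal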
   270  		sqlText = strings.ReplaceAll(sqlText, "\\", "\\\\")
   271  		dbt.mustExec(sqlText)
   272  		defer dbt.mustExec("rm @~/test_put_uncompress_file")
   273  		rows := dbt.mustQuery("ls @~/test_put_uncompress_file")
   274  		defer rows.Close()
   275  		var file, s1, s2, s3 string
   276  		if rows.Next() {
   277  			if err := rows.Scan(&file, &s1, &s2, &s3); err != nil {
   278  				t.Fatal(err)
   279  			}
   280  		}
   281  		if !strings.Contains(file, "test_put_uncompress_file/data.txt") {
   282  			t.Fatalf("should contain file. got: %v", file)
   283  		}
   284  		if strings.Contains(file, "data.txt.gz") {
   285  			t.Fatalf("should not contain file. got: %v", file)
   286  		}
   287  	})
   288  }
   289  
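         // TestPutOverwrite uploads the same file three times and checks that the second PUT is
         // SKIPPED while overwrite=true forces a re-upload with a different MD5.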
   290  func TestPutOverwrite(t *testing.T) {
   291  	tmpDir := t.TempDir()
   292  	testData := filepath.Join(tmpDir, "data.txt")
   293  	f, err := os.Create(testData)
   294  	if err != nil {
    295  		t.Fatal(err)
   296  	}
   297  	f.WriteString("test1,test2\ntest3,test4\n")
   298  	f.Close()
   299  
   300  	runDBTest(t, func(dbt *DBTest) {
   301  		dbt.mustExec("rm @~/test_put_overwrite")
   302  
   303  		f, _ = os.Open(testData)
   304  		rows := dbt.mustQueryContext(
   305  			WithFileStream(context.Background(), f),
   306  			fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
   307  				strings.ReplaceAll(testData, "\\", "/")))
   308  		defer rows.Close()
   309  		f.Close()
   310  		defer dbt.mustExec("rm @~/test_put_overwrite")
   311  		var s0, s1, s2, s3, s4, s5, s6, s7 string
   312  		if rows.Next() {
   313  			if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
   314  				t.Fatal(err)
   315  			}
   316  		}
   317  		if s6 != uploaded.String() {
   318  			t.Fatalf("expected UPLOADED, got %v", s6)
   319  		}
   320  
   321  		rows = dbt.mustQuery("ls @~/test_put_overwrite")
   322  		defer rows.Close()
   323  		assertTrueF(t, rows.Next(), "expected new rows")
   324  		if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
   325  			t.Fatal(err)
   326  		}
   327  		md5Column := s2
   328  
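         		// a second PUT of the same file without overwrite=true should be reported as SKIPPED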
   329  		f, _ = os.Open(testData)
   330  		rows = dbt.mustQueryContext(
   331  			WithFileStream(context.Background(), f),
   332  			fmt.Sprintf("put 'file://%v' @~/test_put_overwrite",
   333  				strings.ReplaceAll(testData, "\\", "/")))
   334  		defer rows.Close()
   335  		f.Close()
   336  		assertTrueF(t, rows.Next(), "expected new rows")
   337  		if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
   338  			t.Fatal(err)
   339  		}
   340  		if s6 != skipped.String() {
   341  			t.Fatalf("expected SKIPPED, got %v", s6)
   342  		}
   343  
   344  		rows = dbt.mustQuery("ls @~/test_put_overwrite")
   345  		defer rows.Close()
   346  		assertTrueF(t, rows.Next(), "expected new rows")
   347  
   348  		if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
   349  			t.Fatal(err)
   350  		}
   351  		if s2 != md5Column {
   352  			t.Fatal("The MD5 column should have stayed the same")
   353  		}
   354  
   355  		f, _ = os.Open(testData)
   356  		rows = dbt.mustQueryContext(
   357  			WithFileStream(context.Background(), f),
   358  			fmt.Sprintf("put 'file://%v' @~/test_put_overwrite overwrite=true",
   359  				strings.ReplaceAll(testData, "\\", "/")))
   360  		defer rows.Close()
   361  		f.Close()
   362  		assertTrueF(t, rows.Next(), "expected new rows")
   363  		if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
   364  			t.Fatal(err)
   365  		}
   366  		if s6 != uploaded.String() {
   367  			t.Fatalf("expected UPLOADED, got %v", s6)
   368  		}
   369  
   370  		rows = dbt.mustQuery("ls @~/test_put_overwrite")
   371  		defer rows.Close()
   372  		assertTrueF(t, rows.Next(), "expected new rows")
   373  		if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
   374  			t.Fatal(err)
   375  		}
   376  		if s0 != fmt.Sprintf("test_put_overwrite/%v.gz", baseName(testData)) {
   377  			t.Fatalf("expected test_put_overwrite/%v.gz, got %v", baseName(testData), s0)
   378  		}
   379  		if s2 == md5Column {
   380  			t.Fatalf("file should have been overwritten.")
   381  		}
   382  	})
   383  }
   384  
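         // TestPutGetFile and TestPutGetStream run the same PUT/GET round trip; the first reads the
         // source file from disk, the second streams it through WithFileStream.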
   385  func TestPutGetFile(t *testing.T) {
   386  	testPutGet(t, false)
   387  }
   388  
   389  func TestPutGetStream(t *testing.T) {
   390  	testPutGet(t, true)
   391  }
   392  
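         // testPutGet writes a small gzipped CSV, PUTs it into a table stage (from disk or from a
         // stream), copies it into the table and back out, then GETs and decompresses the result to
         // compare it with the original contents.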
   393  func testPutGet(t *testing.T, isStream bool) {
   394  	tmpDir := t.TempDir()
   395  	fname := filepath.Join(tmpDir, "test_put_get.txt.gz")
   396  	originalContents := "123,test1\n456,test2\n"
   397  	tableName := randomString(5)
   398  
   399  	var b bytes.Buffer
   400  	gzw := gzip.NewWriter(&b)
   401  	gzw.Write([]byte(originalContents))
   402  	gzw.Close()
   403  	if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil {
   404  		t.Fatal("could not write to gzip file")
   405  	}
   406  
   407  	runDBTest(t, func(dbt *DBTest) {
   408  		dbt.mustExec("create or replace table " + tableName +
   409  			" (a int, b string)")
   410  		defer dbt.mustExec("drop table " + tableName)
   411  		fileStream, err := os.Open(fname)
   412  		if err != nil {
   413  			t.Error(err)
   414  		}
   415  		defer fileStream.Close()
   416  
   417  		var sqlText string
   418  		var rows *RowsExtended
   419  		sql := "put 'file://%v' @%%%v auto_compress=true parallel=30"
   420  		ctx := context.Background()
   421  		if isStream {
   422  			sqlText = fmt.Sprintf(
   423  				sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
   424  			rows = dbt.mustQueryContext(WithFileStream(ctx, fileStream), sqlText)
   425  		} else {
   426  			sqlText = fmt.Sprintf(
   427  				sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
   428  			rows = dbt.mustQuery(sqlText)
   429  		}
   430  		defer rows.Close()
   431  
   432  		var s0, s1, s2, s3, s4, s5, s6, s7 string
   433  		assertTrueF(t, rows.Next(), "expected new rows")
   434  		if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
   435  			t.Fatal(err)
   436  		}
   437  		if s6 != uploaded.String() {
   438  			t.Fatalf("expected %v, got: %v", uploaded, s6)
   439  		}
   440  		// check file is PUT
   441  		dbt.mustQueryAssertCount("ls @%"+tableName, 1)
   442  
   443  		dbt.mustExec("copy into " + tableName)
   444  		dbt.mustExec("rm @%" + tableName)
   445  		dbt.mustQueryAssertCount("ls @%"+tableName, 0)
   446  
   447  		dbt.mustExec(fmt.Sprintf(`copy into @%%%v from %v file_format=(type=csv
   448  			compression='gzip')`, tableName, tableName))
   449  
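         		// GET the staged result back into the temp directory and check each file's name, size, and status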
   450  		sql = fmt.Sprintf("get @%%%v 'file://%v'", tableName, tmpDir)
   451  		sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
   452  		rows2 := dbt.mustQuery(sqlText)
   453  		defer rows2.Close()
   454  		for rows2.Next() {
   455  			if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil {
   456  				t.Error(err)
   457  			}
   458  			if !strings.HasPrefix(s0, "data_") {
   459  				t.Error("a file was not downloaded by GET")
   460  			}
   461  			if v, err := strconv.Atoi(s1); err != nil || v != 36 {
   462  				t.Error("did not return the right file size")
   463  			}
   464  			if s2 != "DOWNLOADED" {
   465  				t.Error("did not return DOWNLOADED status")
   466  			}
   467  			if s3 != "" {
   468  				t.Errorf("returned %v", s3)
   469  			}
   470  		}
   471  
   472  		files, err := filepath.Glob(filepath.Join(tmpDir, "data_*"))
   473  		if err != nil {
   474  			t.Fatal(err)
   475  		}
   476  		fileName := files[0]
   477  		f, err := os.Open(fileName)
   478  		if err != nil {
   479  			t.Error(err)
   480  		}
   481  		defer f.Close()
   482  		gz, err := gzip.NewReader(f)
   483  		if err != nil {
   484  			t.Error(err)
   485  		}
   486  		defer gz.Close()
   487  		var contents string
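         		// decompress the downloaded file chunk by chunk and rebuild its contents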
   488  		for {
   489  			c := make([]byte, defaultChunkBufferSize)
   490  			if n, err := gz.Read(c); err != nil {
   491  				if err == io.EOF {
   492  					contents = contents + string(c[:n])
   493  					break
   494  				}
   495  				t.Error(err)
   496  			} else {
   497  				contents = contents + string(c[:n])
   498  			}
   499  		}
   500  
   501  		if contents != originalContents {
   502  			t.Error("output is different from the original file")
   503  		}
   504  	})
   505  }
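
         // TestPutGetGcsDownscopedCredential repeats the PUT/GET round trip on GCP using downscoped credentials.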
   506  func TestPutGetGcsDownscopedCredential(t *testing.T) {
   507  	if runningOnGithubAction() && !runningOnGCP() {
   508  		t.Skip("skipping non GCP environment")
   509  	}
   510  
   511  	tmpDir, err := os.MkdirTemp("", "put_get")
   512  	if err != nil {
   513  		t.Error(err)
   514  	}
   515  	defer os.RemoveAll(tmpDir)
   516  	fname := filepath.Join(tmpDir, "test_put_get.txt.gz")
   517  	originalContents := "123,test1\n456,test2\n"
   518  	tableName := randomString(5)
   519  
   520  	var b bytes.Buffer
   521  	gzw := gzip.NewWriter(&b)
   522  	gzw.Write([]byte(originalContents))
   523  	gzw.Close()
   524  	if err = os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil {
   525  		t.Fatal("could not write to gzip file")
   526  	}
   527  
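         	// run this round trip with GCS_USE_DOWNSCOPED_CREDENTIAL=true appended to the shared DSN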
   528  	dsn = dsn + "&GCS_USE_DOWNSCOPED_CREDENTIAL=true"
   529  	runDBTest(t, func(dbt *DBTest) {
   530  		dbt.mustExec("create or replace table " + tableName +
   531  			" (a int, b string)")
   532  		fileStream, err := os.Open(fname)
   533  		if err != nil {
   534  			t.Error(err)
   535  		}
   536  		defer func() {
   537  			defer dbt.mustExec("drop table " + tableName)
   538  			if fileStream != nil {
   539  				fileStream.Close()
   540  			}
   541  		}()
   542  
   543  		var sqlText string
   544  		var rows *RowsExtended
   545  		sql := "put 'file://%v' @%%%v auto_compress=true parallel=30"
   546  		sqlText = fmt.Sprintf(
   547  			sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
   548  		rows = dbt.mustQuery(sqlText)
   549  		defer rows.Close()
   550  
   551  		var s0, s1, s2, s3, s4, s5, s6, s7 string
   552  		if rows.Next() {
   553  			if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
   554  				t.Fatal(err)
   555  			}
   556  		}
   557  		if s6 != uploaded.String() {
   558  			t.Fatalf("expected %v, got: %v", uploaded, s6)
   559  		}
   560  		// check file is PUT
   561  		dbt.mustQueryAssertCount("ls @%"+tableName, 1)
   562  
   563  		dbt.mustExec("copy into " + tableName)
   564  		dbt.mustExec("rm @%" + tableName)
   565  		dbt.mustQueryAssertCount("ls @%"+tableName, 0)
   566  
   567  		dbt.mustExec(fmt.Sprintf(`copy into @%%%v from %v file_format=(type=csv
   568              compression='gzip')`, tableName, tableName))
   569  
   570  		sql = fmt.Sprintf("get @%%%v 'file://%v'", tableName, tmpDir)
   571  		sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
   572  		rows2 := dbt.mustQuery(sqlText)
   573  		defer rows2.Close()
   574  		for rows2.Next() {
   575  			if err = rows2.Scan(&s0, &s1, &s2, &s3); err != nil {
   576  				t.Error(err)
   577  			}
   578  			if !strings.HasPrefix(s0, "data_") {
   579  				t.Error("a file was not downloaded by GET")
   580  			}
   581  			if v, err := strconv.Atoi(s1); err != nil || v != 36 {
   582  				t.Error("did not return the right file size")
   583  			}
   584  			if s2 != "DOWNLOADED" {
   585  				t.Error("did not return DOWNLOADED status")
   586  			}
   587  			if s3 != "" {
   588  				t.Errorf("returned %v", s3)
   589  			}
   590  		}
   591  
   592  		files, err := filepath.Glob(filepath.Join(tmpDir, "data_*"))
   593  		if err != nil {
   594  			t.Fatal(err)
   595  		}
   596  		fileName := files[0]
   597  		f, err := os.Open(fileName)
   598  		if err != nil {
   599  			t.Error(err)
   600  		}
   601  		defer f.Close()
   602  		gz, err := gzip.NewReader(f)
   603  		if err != nil {
   604  			t.Error(err)
   605  		}
   606  		var contents string
   607  		for {
   608  			c := make([]byte, defaultChunkBufferSize)
   609  			if n, err := gz.Read(c); err != nil {
   610  				if err == io.EOF {
   611  					contents = contents + string(c[:n])
   612  					break
   613  				}
   614  				t.Error(err)
   615  			} else {
   616  				contents = contents + string(c[:n])
   617  			}
   618  		}
   619  
   620  		if contents != originalContents {
   621  			t.Error("output is different from the original file")
   622  		}
   623  	})
   624  }
   625  
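         // TestPutLargeFile uploads test_data/largefile.txt to a user stage and checks that the
         // auto-compressed .gz copy is listed.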
   626  func TestPutLargeFile(t *testing.T) {
   627  	sourceDir, err := os.Getwd()
   628  	if err != nil {
   629  		t.Fatal(err)
   630  	}
   631  
   632  	runDBTest(t, func(dbt *DBTest) {
   633  		dbt.mustExec("rm @~/test_put_largefile")
   634  		putQuery := fmt.Sprintf("put file://%v/test_data/largefile.txt @%v", sourceDir, "~/test_put_largefile")
   635  		sqlText := strings.ReplaceAll(putQuery, "\\", "\\\\")
   636  		dbt.mustExec(sqlText)
   637  		defer dbt.mustExec("rm @~/test_put_largefile")
   638  		rows := dbt.mustQuery("ls @~/test_put_largefile")
   639  		defer rows.Close()
   640  		var file, s1, s2, s3 string
   641  		if rows.Next() {
   642  			if err := rows.Scan(&file, &s1, &s2, &s3); err != nil {
   643  				t.Fatal(err)
   644  			}
   645  		}
   646  
   647  		if !strings.Contains(file, "largefile.txt.gz") {
   648  			t.Fatalf("should contain file. got: %v", file)
   649  		}
   650  
   651  	})
   652  }
   653  
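         // TestPutGetMaxLOBSize repeats the PUT/GET round trip with generated rows whose column
         // widths range from smallSize up to maxLOBSize.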
   654  func TestPutGetMaxLOBSize(t *testing.T) {
   655  	// the LOB sizes to be tested
   656  	testCases := [5]int{smallSize, originSize, mediumSize, largeSize, maxLOBSize}
   657  
   658  	runDBTest(t, func(dbt *DBTest) {
   659  		for _, tc := range testCases {
   660  			// create the data file
   661  			tmpDir := t.TempDir()
   662  			fname := filepath.Join(tmpDir, "test_put_get.txt.gz")
   663  			tableName := randomString(5)
   664  			originalContents := fmt.Sprintf("%v,%s,%v\n", randomString(tc), randomString(tc), rand.Intn(100000))
   665  
   666  			var b bytes.Buffer
   667  			gzw := gzip.NewWriter(&b)
   668  			gzw.Write([]byte(originalContents))
   669  			gzw.Close()
   670  			err := os.WriteFile(fname, b.Bytes(), readWriteFileMode)
   671  			assertNilF(t, err, "could not write to gzip file")
   672  
   673  			dbt.mustExec(fmt.Sprintf("create or replace table %s (c1 varchar, c2 varchar(%v), c3 int)", tableName, tc))
   674  			defer dbt.mustExec("drop table " + tableName)
   675  			fileStream, err := os.Open(fname)
   676  			assertNilF(t, err)
   677  			defer fileStream.Close()
   678  
   679  			// test PUT command
   680  			var sqlText string
   681  			var rows *RowsExtended
   682  			sql := "put 'file://%v' @%%%v auto_compress=true parallel=30"
   683  			sqlText = fmt.Sprintf(
   684  				sql, strings.ReplaceAll(fname, "\\", "\\\\"), tableName)
   685  			rows = dbt.mustQuery(sqlText)
   686  			defer rows.Close()
   687  
   688  			var s0, s1, s2, s3, s4, s5, s6, s7 string
   689  			assertTrueF(t, rows.Next(), "expected new rows")
   690  			err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7)
   691  			assertNilF(t, err)
    692  			assertEqualF(t, s6, uploaded.String(), fmt.Sprintf("expected %v, got: %v", uploaded, s6))
   694  
   695  			// check file is PUT
   696  			dbt.mustQueryAssertCount("ls @%"+tableName, 1)
   697  
   698  			dbt.mustExec("copy into " + tableName)
   699  			dbt.mustExec("rm @%" + tableName)
   700  			dbt.mustQueryAssertCount("ls @%"+tableName, 0)
   701  
   702  			dbt.mustExec(fmt.Sprintf(`copy into @%%%v from %v file_format=(type=csv
   703  			compression='gzip')`, tableName, tableName))
   704  
   705  			// test GET command
   706  			sql = fmt.Sprintf("get @%%%v 'file://%v'", tableName, tmpDir)
   707  			sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
   708  			rows2 := dbt.mustQuery(sqlText)
   709  			defer rows2.Close()
   710  			for rows2.Next() {
   711  				err = rows2.Scan(&s0, &s1, &s2, &s3)
   712  				assertNilE(t, err)
   713  				assertTrueF(t, strings.HasPrefix(s0, "data_"), "a file was not downloaded by GET")
   714  				assertEqualE(t, s2, "DOWNLOADED", "did not return DOWNLOADED status")
   715  				assertEqualE(t, s3, "", fmt.Sprintf("returned %v", s3))
   716  			}
   717  
   718  			// verify the content in the file
   719  			files, err := filepath.Glob(filepath.Join(tmpDir, "data_*"))
   720  			assertNilF(t, err)
   721  
   722  			fileName := files[0]
   723  			f, err := os.Open(fileName)
   724  			assertNilE(t, err)
   725  
   726  			defer f.Close()
   727  			gz, err := gzip.NewReader(f)
   728  			assertNilE(t, err)
   729  
   730  			defer gz.Close()
   731  			var contents string
   732  			for {
   733  				c := make([]byte, defaultChunkBufferSize)
   734  				if n, err := gz.Read(c); err != nil {
   735  					if err == io.EOF {
   736  						contents = contents + string(c[:n])
   737  						break
   738  					}
   739  					t.Error(err)
   740  				} else {
   741  					contents = contents + string(c[:n])
   742  				}
   743  			}
   744  			assertEqualE(t, contents, originalContents, "output is different from the original file")
   745  		}
   746  	})
   747  }