github.com/snowflakedb/gosnowflake@v1.9.0/put_get_user_stage_test.go (about)

     1  package gosnowflake
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"strconv"
     9  	"strings"
    10  	"testing"
    11  )
    12  
    13  func TestPutGetFileSmallDataViaUserStage(t *testing.T) {
    14  	if os.Getenv("AWS_ACCESS_KEY_ID") == "" {
    15  		t.Skip("this test requires to change the internal parameter")
    16  	}
    17  	putGetUserStage(t, 5, 1, false)
    18  }
    19  
    20  func TestPutGetStreamSmallDataViaUserStage(t *testing.T) {
    21  	if os.Getenv("AWS_ACCESS_KEY_ID") == "" {
    22  		t.Skip("this test requires to change the internal parameter")
    23  	}
    24  	putGetUserStage(t, 1, 1, true)
    25  }
    26  
    27  func putGetUserStage(t *testing.T, numberOfFiles int, numberOfLines int, isStream bool) {
    28  	if os.Getenv("AWS_SECRET_ACCESS_KEY") == "" {
    29  		t.Fatal("no aws secret access key found")
    30  	}
    31  	tmpDir, err := generateKLinesOfNFiles(numberOfLines, numberOfFiles, false, t.TempDir())
    32  	if err != nil {
    33  		t.Error(err)
    34  	}
    35  	var files string
    36  	if isStream {
    37  		list, err := os.ReadDir(tmpDir)
    38  		if err != nil {
    39  			t.Error(err)
    40  		}
    41  		file := list[0].Name()
    42  		files = filepath.Join(tmpDir, file)
    43  	} else {
    44  		files = filepath.Join(tmpDir, "file*")
    45  	}
    46  
    47  	runDBTest(t, func(dbt *DBTest) {
    48  		stageName := fmt.Sprintf("%v_stage_%v_%v", dbname, numberOfFiles, numberOfLines)
    49  		sqlText := `create or replace table %v (aa int, dt date, ts timestamp,
    50  			tsltz timestamp_ltz, tsntz timestamp_ntz, tstz timestamp_tz,
    51  			pct float, ratio number(6,2))`
    52  		dbt.mustExec(fmt.Sprintf(sqlText, dbname))
    53  		userBucket := os.Getenv("SF_AWS_USER_BUCKET")
    54  		if userBucket == "" {
    55  			userBucket = fmt.Sprintf("sfc-eng-regression/%v/reg", username)
    56  		}
    57  		sqlText = `create or replace stage %v url='s3://%v}/%v-%v-%v'
    58  			credentials = (AWS_KEY_ID='%v' AWS_SECRET_KEY='%v')`
    59  		dbt.mustExec(fmt.Sprintf(sqlText, stageName, userBucket, stageName,
    60  			numberOfFiles, numberOfLines, os.Getenv("AWS_ACCESS_KEY_ID"),
    61  			os.Getenv("AWS_SECRET_ACCESS_KEY")))
    62  
    63  		dbt.mustExec("alter session set disable_put_and_get_on_external_stage = false")
    64  		dbt.mustExec("rm @" + stageName)
    65  		var fs *os.File
    66  		if isStream {
    67  			fs, _ = os.Open(files)
    68  			dbt.mustExecContext(WithFileStream(context.Background(), fs),
    69  				fmt.Sprintf("put 'file://%v' @%v", strings.ReplaceAll(
    70  					files, "\\", "\\\\"), stageName))
    71  		} else {
    72  			dbt.mustExec(fmt.Sprintf("put 'file://%v' @%v ", strings.ReplaceAll(files, "\\", "\\\\"), stageName))
    73  		}
    74  		defer func() {
    75  			if isStream {
    76  				fs.Close()
    77  			}
    78  			dbt.mustExec("rm @" + stageName)
    79  			dbt.mustExec("drop stage if exists " + stageName)
    80  			dbt.mustExec("drop table if exists " + dbname)
    81  		}()
    82  		dbt.mustExec(fmt.Sprintf("copy into %v from @%v", dbname, stageName))
    83  
    84  		rows := dbt.mustQuery("select count(*) from " + dbname)
    85  		defer rows.Close()
    86  		var cnt string
    87  		if rows.Next() {
    88  			rows.Scan(&cnt)
    89  		}
    90  		count, err := strconv.Atoi(cnt)
    91  		if err != nil {
    92  			t.Error(err)
    93  		}
    94  		if count != numberOfFiles*numberOfLines {
    95  			t.Errorf("count did not match expected number. count: %v, expected: %v", count, numberOfFiles*numberOfLines)
    96  		}
    97  	})
    98  }
    99  
   100  func TestPutLoadFromUserStage(t *testing.T) {
   101  	runDBTest(t, func(dbt *DBTest) {
   102  		data, err := createTestData(dbt)
   103  		if err != nil {
   104  			t.Skip("snowflake admin account not accessible")
   105  		}
   106  		defer cleanupPut(dbt, data)
   107  		dbt.mustExec("alter session set DISABLE_PUT_AND_GET_ON_EXTERNAL_STAGE=false")
   108  		dbt.mustExec("use warehouse " + data.warehouse)
   109  		dbt.mustExec("use schema " + data.database + ".gotesting_schema")
   110  
   111  		execQuery := fmt.Sprintf(
   112  			`create or replace stage %v url = 's3://%v/%v' credentials = (
   113  			AWS_KEY_ID='%v' AWS_SECRET_KEY='%v')`,
   114  			data.stage, data.userBucket, data.stage,
   115  			data.awsAccessKeyID, data.awsSecretAccessKey)
   116  		dbt.mustExec(execQuery)
   117  
   118  		execQuery = `create or replace table gotest_putget_t2 (c1 STRING,
   119  			c2 STRING, c3 STRING,c4 STRING, c5 STRING, c6 STRING, c7 STRING,
   120  			c8 STRING, c9 STRING)`
   121  		dbt.mustExec(execQuery)
   122  		defer dbt.mustExec("drop table if exists gotest_putget_t2")
   123  		defer dbt.mustExec("drop stage if exists " + data.stage)
   124  
   125  		execQuery = fmt.Sprintf("put file://%v/test_data/orders_10*.csv @%v",
   126  			data.dir, data.stage)
   127  		dbt.mustExec(execQuery)
   128  		dbt.mustQueryAssertCount("ls @%gotest_putget_t2", 0)
   129  
   130  		rows := dbt.mustQuery(fmt.Sprintf(`copy into gotest_putget_t2 from @%v
   131  			file_format = (field_delimiter = '|' error_on_column_count_mismatch
   132  			=false) purge=true`, data.stage))
   133  		defer rows.Close()
   134  		var s0, s1, s2, s3, s4, s5 string
   135  		var s6, s7, s8, s9 interface{}
   136  		orders100 := fmt.Sprintf("s3://%v/%v/orders_100.csv.gz",
   137  			data.userBucket, data.stage)
   138  		orders101 := fmt.Sprintf("s3://%v/%v/orders_101.csv.gz",
   139  			data.userBucket, data.stage)
   140  		for rows.Next() {
   141  			rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9)
   142  			if s0 != orders100 && s0 != orders101 {
   143  				t.Fatalf("copy did not load orders files. got: %v", s0)
   144  			}
   145  		}
   146  		dbt.mustQueryAssertCount(fmt.Sprintf("ls @%v", data.stage), 0)
   147  	})
   148  }