github.com/snowflakedb/gosnowflake@v1.9.0/put_get_with_aws_test.go (about)

     1  // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"compress/gzip"
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"io"
    12  	"net/url"
    13  	"os"
    14  	"path/filepath"
    15  	"strconv"
    16  	"strings"
    17  	"testing"
    18  
    19  	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
    20  	"github.com/aws/aws-sdk-go-v2/service/s3"
    21  )
    22  
    23  func TestLoadS3(t *testing.T) {
    24  	if runningOnGithubAction() && !runningOnAWS() {
    25  		t.Skip("skipping non aws environment")
    26  	}
    27  	runDBTest(t, func(dbt *DBTest) {
    28  		data, err := createTestData(dbt)
    29  		if err != nil {
    30  			t.Skip("snowflake admin account not accessible")
    31  		}
    32  		defer cleanupPut(dbt, data)
    33  		dbt.mustExec("use warehouse " + data.warehouse)
    34  		dbt.mustExec("use schema " + data.database + ".gotesting_schema")
    35  		execQuery := `create or replace table tweets(created_at timestamp,
    36  			id number, id_str string, text string, source string,
    37  			in_reply_to_status_id number, in_reply_to_status_id_str string,
    38  			in_reply_to_user_id number, in_reply_to_user_id_str string,
    39  			in_reply_to_screen_name string, user__id number, user__id_str string,
    40  			user__name string, user__screen_name string, user__location string,
    41  			user__description string, user__url string,
    42  			user__entities__description__urls string, user__protected string,
    43  			user__followers_count number, user__friends_count number,
    44  			user__listed_count number, user__created_at timestamp,
    45  			user__favourites_count number, user__utc_offset number,
    46  			user__time_zone string, user__geo_enabled string,
    47  			user__verified string, user__statuses_count number, user__lang string,
    48  			user__contributors_enabled string, user__is_translator string,
    49  			user__profile_background_color string,
    50  			user__profile_background_image_url string,
    51  			user__profile_background_image_url_https string,
    52  			user__profile_background_tile string, user__profile_image_url string,
    53  			user__profile_image_url_https string, user__profile_link_color string,
    54  			user__profile_sidebar_border_color string,
    55  			user__profile_sidebar_fill_color string, user__profile_text_color string,
    56  			user__profile_use_background_image string, user__default_profile string,
    57  			user__default_profile_image string, user__following string,
    58  			user__follow_request_sent string, user__notifications string,
    59  			geo string, coordinates string, place string, contributors string,
    60  			retweet_count number, favorite_count number, entities__hashtags string,
    61  			entities__symbols string, entities__urls string,
    62  			entities__user_mentions string, favorited string, retweeted string,
    63  			lang string)`
    64  		dbt.mustExec(execQuery)
    65  		defer dbt.mustExec("drop table if exists tweets")
    66  		dbt.mustQueryAssertCount("ls @%tweets", 0)
    67  
    68  		rows := dbt.mustQuery(fmt.Sprintf(`copy into tweets from
    69  			s3://sfc-eng-data/twitter/O1k/tweets/ credentials=(AWS_KEY_ID='%v'
    70  			AWS_SECRET_KEY='%v') file_format=(skip_header=1 null_if=('')
    71  			field_optionally_enclosed_by='\"')`,
    72  			data.awsAccessKeyID, data.awsSecretAccessKey))
    73  		defer rows.Close()
    74  		var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 string
    75  		cnt := 0
    76  		for rows.Next() {
    77  			rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9)
    78  			cnt++
    79  		}
    80  		if cnt != 1 {
    81  			t.Fatal("copy into tweets did not set row count to 1")
    82  		}
    83  		if s0 != "s3://sfc-eng-data/twitter/O1k/tweets/1.csv.gz" {
    84  			t.Fatalf("got %v as file", s0)
    85  		}
    86  	})
    87  }
    88  
    89  func TestPutWithInvalidToken(t *testing.T) {
    90  	runSnowflakeConnTest(t, func(sct *SCTest) {
    91  		if !runningOnAWS() {
    92  			t.Skip("skipping non aws environment")
    93  		}
    94  		tmpDir := t.TempDir()
    95  		fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz")
    96  		originalContents := "123,test1\n456,test2\n"
    97  
    98  		var b bytes.Buffer
    99  		gzw := gzip.NewWriter(&b)
   100  		gzw.Write([]byte(originalContents))
   101  		gzw.Close()
   102  		if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil {
   103  			t.Fatal("could not write to gzip file")
   104  		}
   105  
   106  		tableName := randomString(5)
   107  		sct.mustExec("create or replace table "+tableName+" (a int, b string)", nil)
   108  		defer sct.mustExec("drop table "+tableName, nil)
   109  
   110  		jsonBody, err := json.Marshal(execRequest{
   111  			SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName),
   112  		})
   113  		if err != nil {
   114  			t.Error(err)
   115  		}
   116  		headers := getHeaders()
   117  		headers[httpHeaderAccept] = headerContentTypeApplicationJSON
   118  		data, err := sct.sc.rest.FuncPostQuery(
   119  			sct.sc.ctx, sct.sc.rest, &url.Values{}, headers, jsonBody,
   120  			sct.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sct.sc.ctx), sct.sc.cfg)
   121  		if err != nil {
   122  			t.Fatal(err)
   123  		}
   124  
   125  		s3Util := new(snowflakeS3Client)
   126  		s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false)
   127  		if err != nil {
   128  			t.Error(err)
   129  		}
   130  		client := s3Cli.(*s3.Client)
   131  
   132  		s3Loc, err := s3Util.extractBucketNameAndPath(data.Data.StageInfo.Location)
   133  		if err != nil {
   134  			t.Error(err)
   135  		}
   136  		s3Path := s3Loc.s3Path + baseName(fname) + ".gz"
   137  
   138  		f, err := os.Open(fname)
   139  		if err != nil {
   140  			t.Error(err)
   141  		}
   142  		defer f.Close()
   143  		uploader := manager.NewUploader(client)
   144  		if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
   145  			Bucket: &s3Loc.bucketName,
   146  			Key:    &s3Path,
   147  			Body:   f,
   148  		}); err != nil {
   149  			t.Fatal(err)
   150  		}
   151  
   152  		parentPath := filepath.Dir(filepath.Dir(s3Path)) + "/"
   153  		if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
   154  			Bucket: &s3Loc.bucketName,
   155  			Key:    &parentPath,
   156  			Body:   f,
   157  		}); err == nil {
   158  			t.Fatal("should have failed attempting to put file in parent path")
   159  		}
   160  
   161  		info := execResponseStageInfo{
   162  			Creds: execResponseCredentials{
   163  				AwsID:        data.Data.StageInfo.Creds.AwsID,
   164  				AwsSecretKey: data.Data.StageInfo.Creds.AwsSecretKey,
   165  			},
   166  		}
   167  		s3Cli, err = s3Util.createClient(&info, false)
   168  		if err != nil {
   169  			t.Error(err)
   170  		}
   171  		client = s3Cli.(*s3.Client)
   172  
   173  		uploader = manager.NewUploader(client)
   174  		if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{
   175  			Bucket: &s3Loc.bucketName,
   176  			Key:    &s3Path,
   177  			Body:   f,
   178  		}); err == nil {
   179  			t.Fatal("should have failed attempting to put with missing aws token")
   180  		}
   181  	})
   182  }
   183  
   184  func TestPretendToPutButList(t *testing.T) {
   185  	if runningOnGithubAction() && !runningOnAWS() {
   186  		t.Skip("skipping non aws environment")
   187  	}
   188  	tmpDir := t.TempDir()
   189  	fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz")
   190  	originalContents := "123,test1\n456,test2\n"
   191  
   192  	var b bytes.Buffer
   193  	gzw := gzip.NewWriter(&b)
   194  	gzw.Write([]byte(originalContents))
   195  	gzw.Close()
   196  	if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil {
   197  		t.Fatal("could not write to gzip file")
   198  	}
   199  
   200  	runSnowflakeConnTest(t, func(sct *SCTest) {
   201  		tableName := randomString(5)
   202  		sct.mustExec("create or replace table "+tableName+
   203  			" (a int, b string)", nil)
   204  		defer sct.sc.Exec("drop table "+tableName, nil)
   205  
   206  		jsonBody, err := json.Marshal(execRequest{
   207  			SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName),
   208  		})
   209  		if err != nil {
   210  			t.Error(err)
   211  		}
   212  		headers := getHeaders()
   213  		headers[httpHeaderAccept] = headerContentTypeApplicationJSON
   214  		data, err := sct.sc.rest.FuncPostQuery(
   215  			sct.sc.ctx, sct.sc.rest, &url.Values{}, headers, jsonBody,
   216  			sct.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sct.sc.ctx), sct.sc.cfg)
   217  		if err != nil {
   218  			t.Fatal(err)
   219  		}
   220  
   221  		s3Util := new(snowflakeS3Client)
   222  		s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false)
   223  		if err != nil {
   224  			t.Error(err)
   225  		}
   226  		client := s3Cli.(*s3.Client)
   227  		if _, err = client.ListBuckets(context.Background(),
   228  			&s3.ListBucketsInput{}); err == nil {
   229  			t.Fatal("list buckets should fail")
   230  		}
   231  	})
   232  }
   233  
   234  func TestPutGetAWSStage(t *testing.T) {
   235  	if runningOnGithubAction() && !runningOnAWS() {
   236  		t.Skip("skipping non aws environment")
   237  	}
   238  
   239  	tmpDir := t.TempDir()
   240  	name := "test_put_get.txt.gz"
   241  	fname := filepath.Join(tmpDir, name)
   242  	originalContents := "123,test1\n456,test2\n"
   243  	stageName := "test_put_get_stage_" + randomString(5)
   244  
   245  	var b bytes.Buffer
   246  	gzw := gzip.NewWriter(&b)
   247  	gzw.Write([]byte(originalContents))
   248  	gzw.Close()
   249  	if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil {
   250  		t.Fatal("could not write to gzip file")
   251  	}
   252  
   253  	runDBTest(t, func(dbt *DBTest) {
   254  		var createStageQuery string
   255  		keyID, secretKey, _, err := getAWSCredentials()
   256  		if err != nil {
   257  			t.Skip("snowflake admin account not accessible")
   258  		}
   259  		createStageQuery = fmt.Sprintf(createStageStmt,
   260  			stageName,
   261  			"s3://"+stageName,
   262  			fmt.Sprintf("AWS_KEY_ID='%v' AWS_SECRET_KEY='%v'", keyID, secretKey))
   263  		dbt.mustExec(createStageQuery)
   264  
   265  		defer dbt.mustExec("DROP STAGE IF EXISTS " + stageName)
   266  
   267  		sql := "put 'file://%v' @~/%v auto_compress=false"
   268  		sqlText := fmt.Sprintf(sql, strings.ReplaceAll(fname, "\\", "\\\\"), stageName)
   269  		rows := dbt.mustQuery(sqlText)
   270  		defer rows.Close()
   271  
   272  		var s0, s1, s2, s3, s4, s5, s6, s7 string
   273  		if rows.Next() {
   274  			if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil {
   275  				t.Fatal(err)
   276  			}
   277  		}
   278  		if s6 != uploaded.String() {
   279  			t.Fatalf("expected %v, got: %v", uploaded, s6)
   280  		}
   281  
   282  		sql = fmt.Sprintf("get @~/%v 'file://%v'", stageName, tmpDir)
   283  		sqlText = strings.ReplaceAll(sql, "\\", "\\\\")
   284  		rows = dbt.mustQuery(sqlText)
   285  		defer rows.Close()
   286  		for rows.Next() {
   287  			if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil {
   288  				t.Error(err)
   289  			}
   290  
   291  			if strings.Compare(s0, name) != 0 {
   292  				t.Error("a file was not downloaded by GET")
   293  			}
   294  			if v, err := strconv.Atoi(s1); err != nil || v != 41 {
   295  				t.Error("did not return the right file size")
   296  			}
   297  			if s2 != "DOWNLOADED" {
   298  				t.Error("did not return DOWNLOADED status")
   299  			}
   300  			if s3 != "" {
   301  				t.Errorf("returned %v", s3)
   302  			}
   303  		}
   304  
   305  		files, err := filepath.Glob(filepath.Join(tmpDir, "*"))
   306  		if err != nil {
   307  			t.Fatal(err)
   308  		}
   309  		fileName := files[0]
   310  		f, err := os.Open(fileName)
   311  		if err != nil {
   312  			t.Error(err)
   313  		}
   314  		defer f.Close()
   315  		gz, err := gzip.NewReader(f)
   316  		if err != nil {
   317  			t.Error(err)
   318  		}
   319  		var contents string
   320  		for {
   321  			c := make([]byte, defaultChunkBufferSize)
   322  			if n, err := gz.Read(c); err != nil {
   323  				if err == io.EOF {
   324  					contents = contents + string(c[:n])
   325  					break
   326  				}
   327  				t.Error(err)
   328  			} else {
   329  				contents = contents + string(c[:n])
   330  			}
   331  		}
   332  
   333  		if contents != originalContents {
   334  			t.Error("output is different from the original file")
   335  		}
   336  	})
   337  }