github.com/snowflakedb/gosnowflake@v1.9.0/put_get_with_aws_test.go (about) 1 // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "compress/gzip" 8 "context" 9 "encoding/json" 10 "fmt" 11 "io" 12 "net/url" 13 "os" 14 "path/filepath" 15 "strconv" 16 "strings" 17 "testing" 18 19 "github.com/aws/aws-sdk-go-v2/feature/s3/manager" 20 "github.com/aws/aws-sdk-go-v2/service/s3" 21 ) 22 23 func TestLoadS3(t *testing.T) { 24 if runningOnGithubAction() && !runningOnAWS() { 25 t.Skip("skipping non aws environment") 26 } 27 runDBTest(t, func(dbt *DBTest) { 28 data, err := createTestData(dbt) 29 if err != nil { 30 t.Skip("snowflake admin account not accessible") 31 } 32 defer cleanupPut(dbt, data) 33 dbt.mustExec("use warehouse " + data.warehouse) 34 dbt.mustExec("use schema " + data.database + ".gotesting_schema") 35 execQuery := `create or replace table tweets(created_at timestamp, 36 id number, id_str string, text string, source string, 37 in_reply_to_status_id number, in_reply_to_status_id_str string, 38 in_reply_to_user_id number, in_reply_to_user_id_str string, 39 in_reply_to_screen_name string, user__id number, user__id_str string, 40 user__name string, user__screen_name string, user__location string, 41 user__description string, user__url string, 42 user__entities__description__urls string, user__protected string, 43 user__followers_count number, user__friends_count number, 44 user__listed_count number, user__created_at timestamp, 45 user__favourites_count number, user__utc_offset number, 46 user__time_zone string, user__geo_enabled string, 47 user__verified string, user__statuses_count number, user__lang string, 48 user__contributors_enabled string, user__is_translator string, 49 user__profile_background_color string, 50 user__profile_background_image_url string, 51 user__profile_background_image_url_https string, 52 user__profile_background_tile string, user__profile_image_url string, 53 user__profile_image_url_https string, user__profile_link_color string, 54 user__profile_sidebar_border_color string, 55 user__profile_sidebar_fill_color string, user__profile_text_color string, 56 user__profile_use_background_image string, user__default_profile string, 57 user__default_profile_image string, user__following string, 58 user__follow_request_sent string, user__notifications string, 59 geo string, coordinates string, place string, contributors string, 60 retweet_count number, favorite_count number, entities__hashtags string, 61 entities__symbols string, entities__urls string, 62 entities__user_mentions string, favorited string, retweeted string, 63 lang string)` 64 dbt.mustExec(execQuery) 65 defer dbt.mustExec("drop table if exists tweets") 66 dbt.mustQueryAssertCount("ls @%tweets", 0) 67 68 rows := dbt.mustQuery(fmt.Sprintf(`copy into tweets from 69 s3://sfc-eng-data/twitter/O1k/tweets/ credentials=(AWS_KEY_ID='%v' 70 AWS_SECRET_KEY='%v') file_format=(skip_header=1 null_if=('') 71 field_optionally_enclosed_by='\"')`, 72 data.awsAccessKeyID, data.awsSecretAccessKey)) 73 defer rows.Close() 74 var s0, s1, s2, s3, s4, s5, s6, s7, s8, s9 string 75 cnt := 0 76 for rows.Next() { 77 rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8, &s9) 78 cnt++ 79 } 80 if cnt != 1 { 81 t.Fatal("copy into tweets did not set row count to 1") 82 } 83 if s0 != "s3://sfc-eng-data/twitter/O1k/tweets/1.csv.gz" { 84 t.Fatalf("got %v as file", s0) 85 } 86 }) 87 } 88 89 func TestPutWithInvalidToken(t *testing.T) { 90 runSnowflakeConnTest(t, func(sct *SCTest) { 91 if !runningOnAWS() { 92 t.Skip("skipping non aws environment") 93 } 94 tmpDir := t.TempDir() 95 fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") 96 originalContents := "123,test1\n456,test2\n" 97 98 var b bytes.Buffer 99 gzw := gzip.NewWriter(&b) 100 gzw.Write([]byte(originalContents)) 101 gzw.Close() 102 if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { 103 t.Fatal("could not write to gzip file") 104 } 105 106 tableName := randomString(5) 107 sct.mustExec("create or replace table "+tableName+" (a int, b string)", nil) 108 defer sct.mustExec("drop table "+tableName, nil) 109 110 jsonBody, err := json.Marshal(execRequest{ 111 SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), 112 }) 113 if err != nil { 114 t.Error(err) 115 } 116 headers := getHeaders() 117 headers[httpHeaderAccept] = headerContentTypeApplicationJSON 118 data, err := sct.sc.rest.FuncPostQuery( 119 sct.sc.ctx, sct.sc.rest, &url.Values{}, headers, jsonBody, 120 sct.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sct.sc.ctx), sct.sc.cfg) 121 if err != nil { 122 t.Fatal(err) 123 } 124 125 s3Util := new(snowflakeS3Client) 126 s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) 127 if err != nil { 128 t.Error(err) 129 } 130 client := s3Cli.(*s3.Client) 131 132 s3Loc, err := s3Util.extractBucketNameAndPath(data.Data.StageInfo.Location) 133 if err != nil { 134 t.Error(err) 135 } 136 s3Path := s3Loc.s3Path + baseName(fname) + ".gz" 137 138 f, err := os.Open(fname) 139 if err != nil { 140 t.Error(err) 141 } 142 defer f.Close() 143 uploader := manager.NewUploader(client) 144 if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ 145 Bucket: &s3Loc.bucketName, 146 Key: &s3Path, 147 Body: f, 148 }); err != nil { 149 t.Fatal(err) 150 } 151 152 parentPath := filepath.Dir(filepath.Dir(s3Path)) + "/" 153 if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ 154 Bucket: &s3Loc.bucketName, 155 Key: &parentPath, 156 Body: f, 157 }); err == nil { 158 t.Fatal("should have failed attempting to put file in parent path") 159 } 160 161 info := execResponseStageInfo{ 162 Creds: execResponseCredentials{ 163 AwsID: data.Data.StageInfo.Creds.AwsID, 164 AwsSecretKey: data.Data.StageInfo.Creds.AwsSecretKey, 165 }, 166 } 167 s3Cli, err = s3Util.createClient(&info, false) 168 if err != nil { 169 t.Error(err) 170 } 171 client = s3Cli.(*s3.Client) 172 173 uploader = manager.NewUploader(client) 174 if _, err = uploader.Upload(context.Background(), &s3.PutObjectInput{ 175 Bucket: &s3Loc.bucketName, 176 Key: &s3Path, 177 Body: f, 178 }); err == nil { 179 t.Fatal("should have failed attempting to put with missing aws token") 180 } 181 }) 182 } 183 184 func TestPretendToPutButList(t *testing.T) { 185 if runningOnGithubAction() && !runningOnAWS() { 186 t.Skip("skipping non aws environment") 187 } 188 tmpDir := t.TempDir() 189 fname := filepath.Join(tmpDir, "test_put_get_with_aws.txt.gz") 190 originalContents := "123,test1\n456,test2\n" 191 192 var b bytes.Buffer 193 gzw := gzip.NewWriter(&b) 194 gzw.Write([]byte(originalContents)) 195 gzw.Close() 196 if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { 197 t.Fatal("could not write to gzip file") 198 } 199 200 runSnowflakeConnTest(t, func(sct *SCTest) { 201 tableName := randomString(5) 202 sct.mustExec("create or replace table "+tableName+ 203 " (a int, b string)", nil) 204 defer sct.sc.Exec("drop table "+tableName, nil) 205 206 jsonBody, err := json.Marshal(execRequest{ 207 SQLText: fmt.Sprintf("put 'file://%v' @%%%v", fname, tableName), 208 }) 209 if err != nil { 210 t.Error(err) 211 } 212 headers := getHeaders() 213 headers[httpHeaderAccept] = headerContentTypeApplicationJSON 214 data, err := sct.sc.rest.FuncPostQuery( 215 sct.sc.ctx, sct.sc.rest, &url.Values{}, headers, jsonBody, 216 sct.sc.rest.RequestTimeout, getOrGenerateRequestIDFromContext(sct.sc.ctx), sct.sc.cfg) 217 if err != nil { 218 t.Fatal(err) 219 } 220 221 s3Util := new(snowflakeS3Client) 222 s3Cli, err := s3Util.createClient(&data.Data.StageInfo, false) 223 if err != nil { 224 t.Error(err) 225 } 226 client := s3Cli.(*s3.Client) 227 if _, err = client.ListBuckets(context.Background(), 228 &s3.ListBucketsInput{}); err == nil { 229 t.Fatal("list buckets should fail") 230 } 231 }) 232 } 233 234 func TestPutGetAWSStage(t *testing.T) { 235 if runningOnGithubAction() && !runningOnAWS() { 236 t.Skip("skipping non aws environment") 237 } 238 239 tmpDir := t.TempDir() 240 name := "test_put_get.txt.gz" 241 fname := filepath.Join(tmpDir, name) 242 originalContents := "123,test1\n456,test2\n" 243 stageName := "test_put_get_stage_" + randomString(5) 244 245 var b bytes.Buffer 246 gzw := gzip.NewWriter(&b) 247 gzw.Write([]byte(originalContents)) 248 gzw.Close() 249 if err := os.WriteFile(fname, b.Bytes(), readWriteFileMode); err != nil { 250 t.Fatal("could not write to gzip file") 251 } 252 253 runDBTest(t, func(dbt *DBTest) { 254 var createStageQuery string 255 keyID, secretKey, _, err := getAWSCredentials() 256 if err != nil { 257 t.Skip("snowflake admin account not accessible") 258 } 259 createStageQuery = fmt.Sprintf(createStageStmt, 260 stageName, 261 "s3://"+stageName, 262 fmt.Sprintf("AWS_KEY_ID='%v' AWS_SECRET_KEY='%v'", keyID, secretKey)) 263 dbt.mustExec(createStageQuery) 264 265 defer dbt.mustExec("DROP STAGE IF EXISTS " + stageName) 266 267 sql := "put 'file://%v' @~/%v auto_compress=false" 268 sqlText := fmt.Sprintf(sql, strings.ReplaceAll(fname, "\\", "\\\\"), stageName) 269 rows := dbt.mustQuery(sqlText) 270 defer rows.Close() 271 272 var s0, s1, s2, s3, s4, s5, s6, s7 string 273 if rows.Next() { 274 if err = rows.Scan(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7); err != nil { 275 t.Fatal(err) 276 } 277 } 278 if s6 != uploaded.String() { 279 t.Fatalf("expected %v, got: %v", uploaded, s6) 280 } 281 282 sql = fmt.Sprintf("get @~/%v 'file://%v'", stageName, tmpDir) 283 sqlText = strings.ReplaceAll(sql, "\\", "\\\\") 284 rows = dbt.mustQuery(sqlText) 285 defer rows.Close() 286 for rows.Next() { 287 if err = rows.Scan(&s0, &s1, &s2, &s3); err != nil { 288 t.Error(err) 289 } 290 291 if strings.Compare(s0, name) != 0 { 292 t.Error("a file was not downloaded by GET") 293 } 294 if v, err := strconv.Atoi(s1); err != nil || v != 41 { 295 t.Error("did not return the right file size") 296 } 297 if s2 != "DOWNLOADED" { 298 t.Error("did not return DOWNLOADED status") 299 } 300 if s3 != "" { 301 t.Errorf("returned %v", s3) 302 } 303 } 304 305 files, err := filepath.Glob(filepath.Join(tmpDir, "*")) 306 if err != nil { 307 t.Fatal(err) 308 } 309 fileName := files[0] 310 f, err := os.Open(fileName) 311 if err != nil { 312 t.Error(err) 313 } 314 defer f.Close() 315 gz, err := gzip.NewReader(f) 316 if err != nil { 317 t.Error(err) 318 } 319 var contents string 320 for { 321 c := make([]byte, defaultChunkBufferSize) 322 if n, err := gz.Read(c); err != nil { 323 if err == io.EOF { 324 contents = contents + string(c[:n]) 325 break 326 } 327 t.Error(err) 328 } else { 329 contents = contents + string(c[:n]) 330 } 331 } 332 333 if contents != originalContents { 334 t.Error("output is different from the original file") 335 } 336 }) 337 }