github.com/snowflakedb/gosnowflake@v1.9.0/connection_test.go (about) 1 // Copyright (c) 2019-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "database/sql" 8 "database/sql/driver" 9 "encoding/json" 10 "errors" 11 "fmt" 12 "io" 13 "net/http" 14 "net/url" 15 "os" 16 "strings" 17 "sync" 18 "testing" 19 "time" 20 ) 21 22 const ( 23 serviceNameStub = "SV" 24 serviceNameAppend = "a" 25 ) 26 27 func TestInvalidConnection(t *testing.T) { 28 db := openDB(t) 29 if err := db.Close(); err != nil { 30 t.Error("should not cause error in Close") 31 } 32 if err := db.Close(); err != nil { 33 t.Error("should not cause error in the second call of Close") 34 } 35 if _, err := db.ExecContext(context.Background(), "CREATE TABLE OR REPLACE test0(c1 int)"); err == nil { 36 t.Error("should fail to run Exec") 37 } 38 if _, err := db.QueryContext(context.Background(), "SELECT CURRENT_TIMESTAMP()"); err == nil { 39 t.Error("should fail to run Query") 40 } 41 if _, err := db.BeginTx(context.Background(), nil); err == nil { 42 t.Error("should fail to run Begin") 43 } 44 } 45 46 // postQueryMock generates a response based on the X-Snowflake-Service header, 47 // to generate a response with the SERVICE_NAME field appending a character at 48 // the end of the header. This way it could test both the send and receive logic 49 func postQueryMock(_ context.Context, _ *snowflakeRestful, _ *url.Values, 50 headers map[string]string, _ []byte, _ time.Duration, _ UUID, 51 _ *Config) (*execResponse, error) { 52 var serviceName string 53 if serviceHeader, ok := headers[httpHeaderServiceName]; ok { 54 serviceName = serviceHeader + serviceNameAppend 55 } else { 56 serviceName = serviceNameStub 57 } 58 59 dd := &execResponseData{ 60 Parameters: []nameValueParameter{{"SERVICE_NAME", serviceName}}, 61 } 62 return &execResponse{ 63 Data: *dd, 64 Message: "", 65 Code: "0", 66 Success: true, 67 }, nil 68 } 69 70 func TestExecWithEmptyRequestID(t *testing.T) { 71 ctx := WithRequestID(context.Background(), nilUUID) 72 postQueryMock := func(_ context.Context, _ *snowflakeRestful, 73 _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, 74 requestID UUID, _ *Config) (*execResponse, error) { 75 // ensure the same requestID from context is used 76 if len(requestID) == 0 { 77 t.Fatal("requestID is empty") 78 } 79 dd := &execResponseData{} 80 return &execResponse{ 81 Data: *dd, 82 Message: "", 83 Code: "0", 84 Success: true, 85 }, nil 86 } 87 88 sr := &snowflakeRestful{ 89 FuncPostQuery: postQueryMock, 90 } 91 92 sc := &snowflakeConn{ 93 cfg: &Config{Params: map[string]*string{}}, 94 rest: sr, 95 queryContextCache: (&queryContextCache{}).init(), 96 } 97 if _, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ 98 false /* describeOnly */, nil); err != nil { 99 t.Fatalf("err: %v", err) 100 } 101 } 102 103 func TestGetQueryResultUsesTokenFromTokenAccessor(t *testing.T) { 104 ta := getSimpleTokenAccessor() 105 token := "snowflake-test-token" 106 ta.SetTokens(token, "", 1) 107 funcGetMock := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, 108 headers map[string]string, _ time.Duration) (*http.Response, error) { 109 if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, token) { 110 t.Fatalf("header authorization key is not correct: %v", headers[headerAuthorizationKey]) 111 } 112 dd := &execResponseData{} 113 er := &execResponse{ 114 Data: *dd, 115 Message: "", 116 Code: sessionExpiredCode, 117 Success: true, 118 } 119 ba, err := json.Marshal(er) 120 if err != nil { 121 t.Fatalf("err: %v", err) 122 } 123 return &http.Response{ 124 StatusCode: http.StatusOK, 125 Body: &fakeResponseBody{body: ba}, 126 }, nil 127 } 128 sr := &snowflakeRestful{ 129 FuncGet: funcGetMock, 130 TokenAccessor: ta, 131 } 132 sc := &snowflakeConn{ 133 cfg: &Config{Params: map[string]*string{}}, 134 rest: sr, 135 currentTimeProvider: defaultTimeProvider, 136 } 137 if _, err := sc.getQueryResultResp(context.Background(), ""); err != nil { 138 t.Fatalf("err: %v", err) 139 } 140 } 141 142 func TestExecWithSpecificRequestID(t *testing.T) { 143 origRequestID := NewUUID() 144 ctx := WithRequestID(context.Background(), origRequestID) 145 postQueryMock := func(_ context.Context, _ *snowflakeRestful, 146 _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, 147 requestID UUID, _ *Config) (*execResponse, error) { 148 // ensure the same requestID from context is used 149 if requestID != origRequestID { 150 t.Fatal("requestID doesn't match") 151 } 152 dd := &execResponseData{} 153 return &execResponse{ 154 Data: *dd, 155 Message: "", 156 Code: "0", 157 Success: true, 158 }, nil 159 } 160 161 sr := &snowflakeRestful{ 162 FuncPostQuery: postQueryMock, 163 } 164 165 sc := &snowflakeConn{ 166 cfg: &Config{Params: map[string]*string{}}, 167 rest: sr, 168 queryContextCache: (&queryContextCache{}).init(), 169 } 170 if _, err := sc.exec(ctx, "", false /* noResult */, false, /* isInternal */ 171 false /* describeOnly */, nil); err != nil { 172 t.Fatalf("err: %v", err) 173 } 174 } 175 176 // TestServiceName tests two things: 177 // 1. request header contains X-Snowflake-Service if the cfg parameters 178 // contains SERVICE_NAME 179 // 2. SERVICE_NAME is updated by response payload 180 // Uses interactive postQueryMock that generates a response based on header 181 func TestServiceName(t *testing.T) { 182 sr := &snowflakeRestful{ 183 FuncPostQuery: postQueryMock, 184 } 185 186 sc := &snowflakeConn{ 187 cfg: &Config{Params: map[string]*string{}}, 188 rest: sr, 189 queryContextCache: (&queryContextCache{}).init(), 190 } 191 192 expectServiceName := serviceNameStub 193 for i := 0; i < 5; i++ { 194 sc.exec(context.Background(), "", false, /* noResult */ 195 false /* isInternal */, false /* describeOnly */, nil) 196 if actualServiceName, ok := sc.cfg.Params[serviceName]; ok { 197 if *actualServiceName != expectServiceName { 198 t.Errorf("service name mis-match. expected %v, actual %v", 199 expectServiceName, actualServiceName) 200 } 201 } else { 202 t.Error("No service name in the response") 203 } 204 expectServiceName += serviceNameAppend 205 } 206 } 207 208 var closedSessionCount = 0 209 210 var testTelemetry = &snowflakeTelemetry{ 211 mutex: &sync.Mutex{}, 212 } 213 214 func closeSessionMock(_ context.Context, _ *snowflakeRestful, _ time.Duration) error { 215 closedSessionCount++ 216 return &SnowflakeError{ 217 Number: ErrSessionGone, 218 } 219 } 220 221 func TestCloseIgnoreSessionGone(t *testing.T) { 222 sr := &snowflakeRestful{ 223 FuncCloseSession: closeSessionMock, 224 } 225 sc := &snowflakeConn{ 226 cfg: &Config{Params: map[string]*string{}}, 227 rest: sr, 228 telemetry: testTelemetry, 229 queryContextCache: (&queryContextCache{}).init(), 230 } 231 232 if sc.Close() != nil { 233 t.Error("Close should let go session gone error") 234 } 235 } 236 237 func TestClientSessionPersist(t *testing.T) { 238 sr := &snowflakeRestful{ 239 FuncCloseSession: closeSessionMock, 240 } 241 sc := &snowflakeConn{ 242 cfg: &Config{Params: map[string]*string{}}, 243 rest: sr, 244 telemetry: testTelemetry, 245 } 246 sc.cfg.KeepSessionAlive = true 247 count := closedSessionCount 248 if sc.Close() != nil { 249 t.Error("Connection close should not return error") 250 } 251 if count != closedSessionCount { 252 t.Fatal("close session was called") 253 } 254 } 255 256 func TestFetchResultByQueryID(t *testing.T) { 257 fetchResultByQueryID(t, nil, nil) 258 } 259 260 func TestFetchRunningQueryByID(t *testing.T) { 261 fetchResultByQueryID(t, returnQueryIsRunningStatus, nil) 262 } 263 264 func TestFetchErrorQueryByID(t *testing.T) { 265 fetchResultByQueryID(t, returnQueryIsErrStatus, &SnowflakeError{ 266 Number: ErrQueryReportedError}) 267 } 268 269 func TestFetchMalformedJsonQueryByID(t *testing.T) { 270 expectedErr := errors.New("invalid character '}' after object key") 271 fetchResultByQueryID(t, returnQueryMalformedJSON, expectedErr) 272 } 273 274 func customGetQuery(ctx context.Context, rest *snowflakeRestful, url *url.URL, 275 vals map[string]string, _ time.Duration, jsonStr string) ( 276 *http.Response, error) { 277 if strings.Contains(url.Path, "/monitoring/queries/") { 278 return &http.Response{ 279 StatusCode: http.StatusOK, 280 Body: io.NopCloser(strings.NewReader(jsonStr)), 281 }, nil 282 } 283 return getRestful(ctx, rest, url, vals, rest.RequestTimeout) 284 } 285 286 func returnQueryIsRunningStatus(ctx context.Context, rest *snowflakeRestful, fullURL *url.URL, 287 vals map[string]string, duration time.Duration) (*http.Response, error) { 288 jsonStr := `{"data" : { "queries" : [{"status" : "RUNNING", "state" : 289 "FILE_SET_INITIALIZATION", "errorCode" : "", "errorMessage" : null}] }, 290 "code" : null, "message" : null, "success" : true }` 291 return customGetQuery(ctx, rest, fullURL, vals, duration, jsonStr) 292 } 293 294 func returnQueryIsErrStatus(ctx context.Context, rest *snowflakeRestful, fullURL *url.URL, 295 vals map[string]string, duration time.Duration) (*http.Response, error) { 296 jsonStr := `{"data" : { "queries" : [{"status" : "FAILED_WITH_ERROR", 297 "errorCode" : "", "errorMessage" : ""}] }, "code" : null, "message" : 298 null, "success" : true }` 299 return customGetQuery(ctx, rest, fullURL, vals, duration, jsonStr) 300 } 301 302 func returnQueryMalformedJSON(ctx context.Context, rest *snowflakeRestful, fullURL *url.URL, 303 vals map[string]string, duration time.Duration) (*http.Response, error) { 304 jsonStr := `{"malformedJson"}` 305 return customGetQuery(ctx, rest, fullURL, vals, duration, jsonStr) 306 } 307 308 // this function is going to: 1, create a table, 2, query on this table, 309 // 3, fetch result of query in step 2, mock running status and error status 310 // of that query. 311 func fetchResultByQueryID( 312 t *testing.T, 313 customGet funcGetType, 314 expectedFetchErr error) error { 315 config, err := ParseDSN(dsn) 316 if err != nil { 317 return err 318 } 319 ctx := context.Background() 320 sc, err := buildSnowflakeConn(ctx, *config) 321 if customGet != nil { 322 sc.rest.FuncGet = customGet 323 } 324 if err != nil { 325 return err 326 } 327 if err = authenticateWithConfig(sc); err != nil { 328 return err 329 } 330 331 if _, err = sc.Exec(`create or replace table ut_conn(c1 number, c2 string) 332 as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 333 from table(generator(rowcount => 100)))`, nil); err != nil { 334 t.Fatalf("err: %v", err) 335 } 336 337 rows1, err := sc.QueryContext(ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil) 338 if err != nil { 339 t.Fatalf("Query failed: %v", err) 340 } 341 342 qid := rows1.(SnowflakeResult).GetQueryID() 343 newCtx := WithFetchResultByID(ctx, qid) 344 345 rows2, err := sc.QueryContext(newCtx, "", nil) 346 if err != nil { 347 snowflakeErr, ok := err.(*SnowflakeError) 348 if ok && expectedFetchErr != nil { // got expected error number 349 if expectedSnowflakeErr, ok := expectedFetchErr.(*SnowflakeError); ok { 350 if expectedSnowflakeErr.Number == snowflakeErr.Number { 351 return nil 352 } 353 } 354 } else if !ok { // not a SnowflakeError 355 if strings.Contains(err.Error(), expectedFetchErr.Error()) { 356 return nil 357 } 358 } 359 t.Fatalf("Fetch Query Result by ID failed: %v", err) 360 } 361 362 dest := make([]driver.Value, 2) 363 cnt := 0 364 for { 365 if err = rows2.Next(dest); err != nil { 366 if err == io.EOF { 367 break 368 } else { 369 t.Fatalf("unexpected error: %v", err) 370 } 371 } 372 cnt++ 373 } 374 if cnt != 10 { 375 t.Fatalf("rowcount is not expected 10: %v", cnt) 376 } 377 return nil 378 } 379 380 func TestPrivateLink(t *testing.T) { 381 if _, err := buildSnowflakeConn(context.Background(), Config{ 382 Account: "testaccount", 383 User: "testuser", 384 Password: "testpassword", 385 Host: "testaccount.us-east-1.privatelink.snowflakecomputing.com", 386 }); err != nil { 387 t.Error(err) 388 } 389 ocspURL := os.Getenv(cacheServerURLEnv) 390 expectedURL := "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/ocsp_response_cache.json" 391 if ocspURL != expectedURL { 392 t.Errorf("expected: %v, got: %v", expectedURL, ocspURL) 393 } 394 retryURL := os.Getenv(ocspRetryURLEnv) 395 expectedURL = "http://ocsp.testaccount.us-east-1.privatelink.snowflakecomputing.com/retry/%v/%v" 396 if retryURL != expectedURL { 397 t.Errorf("expected: %v, got: %v", expectedURL, retryURL) 398 } 399 } 400 401 func TestGetQueryStatus(t *testing.T) { 402 runSnowflakeConnTest(t, func(sct *SCTest) { 403 sct.mustExec(`create or replace table ut_conn(c1 number, c2 string) 404 as (select seq4() as seq, concat('str',to_varchar(seq)) as str1 405 from table(generator(rowcount => 100)))`, 406 nil) 407 408 rows := sct.mustQueryContext(sct.sc.ctx, "select min(c1) as ms, sum(c1) from ut_conn group by (c1 % 10) order by ms", nil) 409 qid := rows.(SnowflakeResult).GetQueryID() 410 411 // use conn as type holder for SnowflakeConnection placeholder 412 var conn interface{} = sct.sc 413 qStatus, err := conn.(SnowflakeConnection).GetQueryStatus(sct.sc.ctx, qid) 414 if err != nil { 415 t.Errorf("failed to get query status err = %s", err.Error()) 416 return 417 } 418 if qStatus == nil { 419 t.Error("there was no query status returned") 420 return 421 } 422 if qStatus.ErrorCode != "" || qStatus.ScanBytes <= 0 || qStatus.ProducedRows != 10 { 423 t.Errorf("expected no error. got: %v, scan bytes: %v, produced rows: %v", 424 qStatus.ErrorCode, qStatus.ScanBytes, qStatus.ProducedRows) 425 return 426 } 427 }) 428 } 429 430 func TestGetInvalidQueryStatus(t *testing.T) { 431 runSnowflakeConnTest(t, func(sct *SCTest) { 432 sct.sc.rest.RequestTimeout = 1 * time.Second 433 434 qStatus, err := sct.sc.checkQueryStatus(sct.sc.ctx, "1234") 435 if err == nil || qStatus != nil { 436 t.Error("expected an error") 437 } 438 }) 439 } 440 441 func TestExecWithServerSideError(t *testing.T) { 442 postQueryMock := func(_ context.Context, _ *snowflakeRestful, 443 _ *url.Values, _ map[string]string, _ []byte, _ time.Duration, 444 requestID UUID, _ *Config) (*execResponse, error) { 445 dd := &execResponseData{} 446 return &execResponse{ 447 Data: *dd, 448 Message: "", 449 Code: "", 450 Success: false, 451 }, nil 452 } 453 454 sr := &snowflakeRestful{ 455 FuncPostQuery: postQueryMock, 456 } 457 sc := &snowflakeConn{ 458 cfg: &Config{Params: map[string]*string{}}, 459 rest: sr, 460 telemetry: testTelemetry, 461 } 462 _, err := sc.exec(context.Background(), "", false, /* noResult */ 463 false /* isInternal */, false /* describeOnly */, nil) 464 if err == nil { 465 t.Error("expected a server side error") 466 } 467 sfe := err.(*SnowflakeError) 468 errUnknownError := errUnknownError() 469 if sfe.Number != -1 || sfe.SQLState != "-1" || sfe.QueryID != "-1" { 470 t.Errorf("incorrect snowflake error. expected: %v, got: %v", errUnknownError, *sfe) 471 } 472 if !strings.Contains(sfe.Message, "an unknown server side error occurred") { 473 t.Errorf("incorrect message. expected: %v, got: %v", errUnknownError.Message, sfe.Message) 474 } 475 } 476 477 func TestConcurrentReadOnParams(t *testing.T) { 478 config, err := ParseDSN(dsn) 479 if err != nil { 480 t.Fatal("Failed to parse dsn") 481 } 482 connector := NewConnector(SnowflakeDriver{}, *config) 483 db := sql.OpenDB(connector) 484 defer db.Close() 485 wg := sync.WaitGroup{} 486 for i := 0; i < 10; i++ { 487 wg.Add(1) 488 go func() { 489 for c := 0; c < 10; c++ { 490 stmt, err := db.PrepareContext(context.Background(), "SELECT table_schema FROM information_schema.columns WHERE table_schema = ? LIMIT 1") 491 if err != nil { 492 t.Error(err) 493 } 494 rows, err := stmt.Query("INFORMATION_SCHEMA") 495 if err != nil { 496 t.Error(err) 497 } 498 if rows == nil { 499 continue 500 } 501 rows.Next() 502 var tableName string 503 err = rows.Scan(&tableName) 504 if err != nil { 505 t.Error(err) 506 } 507 _ = rows.Close() 508 } 509 wg.Done() 510 }() 511 } 512 wg.Wait() 513 } 514 515 func postQueryTest(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) { 516 return nil, errors.New("failed to get query response") 517 } 518 519 func postQueryFail(_ context.Context, _ *snowflakeRestful, _ *url.Values, headers map[string]string, _ []byte, _ time.Duration, _ UUID, _ *Config) (*execResponse, error) { 520 dd := &execResponseData{ 521 QueryID: "1eFhmhe23242kmfd540GgGre", 522 SQLState: "22008", 523 } 524 return &execResponse{ 525 Data: *dd, 526 Message: "failed to get query response", 527 Code: "12345", 528 Success: false, 529 }, errors.New("failed to get query response") 530 } 531 532 func TestErrorReportingOnConcurrentFails(t *testing.T) { 533 db := openDB(t) 534 defer db.Close() 535 var wg sync.WaitGroup 536 n := 5 537 wg.Add(3 * n) 538 for i := 0; i < n; i++ { 539 go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_ABC", "TABLE_ABC", t, &wg) 540 go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_DEF", "TABLE_DEF", t, &wg) 541 go executeQueryAndConfirmMessage(db, "SELECT * FROM TABLE_GHI", "TABLE_GHI", t, &wg) 542 } 543 wg.Wait() 544 } 545 546 func executeQueryAndConfirmMessage(db *sql.DB, query string, expectedErrorTable string, t *testing.T, wg *sync.WaitGroup) { 547 defer wg.Done() 548 _, err := db.Exec(query) 549 message := err.(*SnowflakeError).Message 550 if !strings.Contains(message, expectedErrorTable) { 551 t.Errorf("QueryID: %s, Message %s ###### Expected error message table name: %s", 552 err.(*SnowflakeError).QueryID, err.(*SnowflakeError).Message, expectedErrorTable) 553 } 554 } 555 556 func TestQueryArrowStreamError(t *testing.T) { 557 runSnowflakeConnTest(t, func(sct *SCTest) { 558 numrows := 50000 // approximately 10 ArrowBatch objects 559 query := fmt.Sprintf(selectRandomGenerator, numrows) 560 sct.sc.rest = &snowflakeRestful{ 561 FuncPostQuery: postQueryTest, 562 FuncCloseSession: closeSessionMock, 563 TokenAccessor: getSimpleTokenAccessor(), 564 RequestTimeout: 10, 565 } 566 _, err := sct.sc.QueryArrowStream(sct.sc.ctx, query) 567 if err == nil { 568 t.Error("should have raised an error") 569 } 570 571 sct.sc.rest.FuncPostQuery = postQueryFail 572 _, err = sct.sc.QueryArrowStream(sct.sc.ctx, query) 573 if err == nil { 574 t.Error("should have raised an error") 575 } 576 _, ok := err.(*SnowflakeError) 577 if !ok { 578 t.Fatalf("should be snowflake error. err: %v", err) 579 } 580 }) 581 } 582 583 func TestExecContextError(t *testing.T) { 584 runSnowflakeConnTest(t, func(sct *SCTest) { 585 sct.sc.rest = &snowflakeRestful{ 586 FuncPostQuery: postQueryTest, 587 FuncCloseSession: closeSessionMock, 588 TokenAccessor: getSimpleTokenAccessor(), 589 RequestTimeout: 10, 590 } 591 592 _, err := sct.sc.ExecContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{}) 593 if err == nil { 594 t.Fatalf("should have raised an error") 595 } 596 597 sct.sc.rest.FuncPostQuery = postQueryFail 598 _, err = sct.sc.ExecContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{}) 599 if err == nil { 600 t.Fatalf("should have raised an error") 601 } 602 }) 603 } 604 605 func TestQueryContextError(t *testing.T) { 606 runSnowflakeConnTest(t, func(sct *SCTest) { 607 sct.sc.rest = &snowflakeRestful{ 608 FuncPostQuery: postQueryTest, 609 FuncCloseSession: closeSessionMock, 610 TokenAccessor: getSimpleTokenAccessor(), 611 RequestTimeout: 10, 612 } 613 _, err := sct.sc.QueryContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{}) 614 if err == nil { 615 t.Fatalf("should have raised an error") 616 } 617 618 sct.sc.rest.FuncPostQuery = postQueryFail 619 _, err = sct.sc.QueryContext(sct.sc.ctx, "SELECT 1", []driver.NamedValue{}) 620 if err == nil { 621 t.Fatalf("should have raised an error") 622 } 623 _, ok := err.(*SnowflakeError) 624 if !ok { 625 t.Fatalf("should be snowflake error. err: %v", err) 626 } 627 }) 628 } 629 630 func TestPrepareQuery(t *testing.T) { 631 runSnowflakeConnTest(t, func(sct *SCTest) { 632 _, err := sct.sc.Prepare("SELECT 1") 633 634 if err != nil { 635 t.Fatalf("failed to prepare query. err: %v", err) 636 } 637 }) 638 } 639 640 func TestBeginCreatesTransaction(t *testing.T) { 641 runSnowflakeConnTest(t, func(sct *SCTest) { 642 tx, _ := sct.sc.Begin() 643 if tx == nil { 644 t.Fatal("should have created a transaction with connection") 645 } 646 }) 647 }