github.com/snowflakedb/gosnowflake@v1.9.0/driver_test.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "crypto/rsa" 8 "database/sql" 9 "database/sql/driver" 10 "flag" 11 "fmt" 12 "net/http" 13 "net/url" 14 "os" 15 "os/signal" 16 "strings" 17 "syscall" 18 "testing" 19 "time" 20 ) 21 22 var ( 23 username string 24 pass string 25 account string 26 dbname string 27 schemaname string 28 warehouse string 29 rolename string 30 dsn string 31 host string 32 port string 33 protocol string 34 customPrivateKey bool // Whether user has specified the private key path 35 testPrivKey *rsa.PrivateKey // Valid private key used for all test cases 36 ) 37 38 const ( 39 selectNumberSQL = "SELECT %s::NUMBER(%v, %v) AS C" 40 selectVariousTypes = "SELECT 1.0::NUMBER(30,2) as C1, 2::NUMBER(38,0) AS C2, 't3' AS C3, 4.2::DOUBLE AS C4, 'abcd'::BINARY AS C5, true AS C6" 41 selectRandomGenerator = "SELECT SEQ8(), RANDSTR(1000, RANDOM()) FROM TABLE(GENERATOR(ROWCOUNT=>%v))" 42 PSTLocation = "America/Los_Angeles" 43 ) 44 45 // The tests require the following parameters in the environment variables. 46 // SNOWFLAKE_TEST_USER, SNOWFLAKE_TEST_PASSWORD, SNOWFLAKE_TEST_ACCOUNT, 47 // SNOWFLAKE_TEST_DATABASE, SNOWFLAKE_TEST_SCHEMA, SNOWFLAKE_TEST_WAREHOUSE. 48 // Optionally you may specify SNOWFLAKE_TEST_PROTOCOL, SNOWFLAKE_TEST_HOST 49 // and SNOWFLAKE_TEST_PORT to specify the endpoint. 
50 func init() { 51 // get environment variables 52 env := func(key, defaultValue string) string { 53 if value := os.Getenv(key); value != "" { 54 return value 55 } 56 return defaultValue 57 } 58 username = env("SNOWFLAKE_TEST_USER", "testuser") 59 pass = env("SNOWFLAKE_TEST_PASSWORD", "testpassword") 60 account = env("SNOWFLAKE_TEST_ACCOUNT", "testaccount") 61 dbname = env("SNOWFLAKE_TEST_DATABASE", "testdb") 62 schemaname = env("SNOWFLAKE_TEST_SCHEMA", "public") 63 rolename = env("SNOWFLAKE_TEST_ROLE", "sysadmin") 64 warehouse = env("SNOWFLAKE_TEST_WAREHOUSE", "testwarehouse") 65 66 protocol = env("SNOWFLAKE_TEST_PROTOCOL", "https") 67 host = os.Getenv("SNOWFLAKE_TEST_HOST") 68 port = env("SNOWFLAKE_TEST_PORT", "443") 69 if host == "" { 70 host = fmt.Sprintf("%s.snowflakecomputing.com", account) 71 } else { 72 host = fmt.Sprintf("%s:%s", host, port) 73 } 74 75 setupPrivateKey() 76 77 createDSN("UTC") 78 } 79 80 func createDSN(timezone string) { 81 dsn = fmt.Sprintf("%s:%s@%s/%s/%s", username, pass, host, dbname, schemaname) 82 83 parameters := url.Values{} 84 parameters.Add("timezone", timezone) 85 if protocol != "" { 86 parameters.Add("protocol", protocol) 87 } 88 if account != "" { 89 parameters.Add("account", account) 90 } 91 if warehouse != "" { 92 parameters.Add("warehouse", warehouse) 93 } 94 if rolename != "" { 95 parameters.Add("role", rolename) 96 } 97 98 if len(parameters) > 0 { 99 dsn += "?" 
+ parameters.Encode() 100 } 101 } 102 103 // setup creates a test schema so that all tests can run in the same schema 104 func setup() (string, error) { 105 env := func(key, defaultValue string) string { 106 if value := os.Getenv(key); value != "" { 107 return value 108 } 109 return defaultValue 110 } 111 112 orgSchemaname := schemaname 113 if env("GITHUB_WORKFLOW", "") != "" { 114 githubRunnerID := env("RUNNER_TRACKING_ID", "GITHUB_RUNNER_ID") 115 githubRunnerID = strings.ReplaceAll(githubRunnerID, "-", "_") 116 githubSha := env("GITHUB_SHA", "GITHUB_SHA") 117 schemaname = fmt.Sprintf("%v_%v", githubRunnerID, githubSha) 118 } else { 119 schemaname = fmt.Sprintf("golang_%v", time.Now().UnixNano()) 120 } 121 var db *sql.DB 122 var err error 123 if db, err = sql.Open("snowflake", dsn); err != nil { 124 return "", fmt.Errorf("failed to open db. err: %v", err) 125 } 126 defer db.Close() 127 if _, err = db.Exec(fmt.Sprintf("CREATE OR REPLACE SCHEMA %v", schemaname)); err != nil { 128 return "", fmt.Errorf("failed to create schema. %v", err) 129 } 130 createDSN("UTC") 131 return orgSchemaname, nil 132 } 133 134 // teardown drops the test schema 135 func teardown() error { 136 var db *sql.DB 137 var err error 138 if db, err = sql.Open("snowflake", dsn); err != nil { 139 return fmt.Errorf("failed to open db. %v, err: %v", dsn, err) 140 } 141 defer db.Close() 142 if _, err = db.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %v", schemaname)); err != nil { 143 return fmt.Errorf("failed to create schema. 
%v", err) 144 } 145 return nil 146 } 147 148 func TestMain(m *testing.M) { 149 flag.Parse() 150 signal.Ignore(syscall.SIGQUIT) 151 if value := os.Getenv("SKIP_SETUP"); value != "" { 152 os.Exit(m.Run()) 153 } 154 155 if _, err := setup(); err != nil { 156 panic(err) 157 } 158 ret := m.Run() 159 if err := teardown(); err != nil { 160 panic(err) 161 } 162 os.Exit(ret) 163 } 164 165 type DBTest struct { 166 *testing.T 167 conn *sql.Conn 168 } 169 170 func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExtended) { 171 // handler interrupt signal 172 ctx, cancel := context.WithCancel(context.Background()) 173 c := make(chan os.Signal, 1) 174 c0 := make(chan bool, 1) 175 signal.Notify(c, os.Interrupt) 176 defer func() { 177 signal.Stop(c) 178 }() 179 go func() { 180 select { 181 case <-c: 182 fmt.Println("Caught signal, canceling...") 183 cancel() 184 case <-ctx.Done(): 185 fmt.Println("Done") 186 case <-c0: 187 } 188 close(c) 189 }() 190 191 rs, err := dbt.conn.QueryContext(ctx, query, args...) 192 if err != nil { 193 dbt.fail("query", query, err) 194 } 195 return &RowsExtended{ 196 rows: rs, 197 closeChan: &c0, 198 } 199 } 200 201 func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...interface{}) (rows *RowsExtended) { 202 // handler interrupt signal 203 ctx, cancel := context.WithCancel(ctx) 204 c := make(chan os.Signal, 1) 205 c0 := make(chan bool, 1) 206 signal.Notify(c, os.Interrupt) 207 defer func() { 208 signal.Stop(c) 209 }() 210 go func() { 211 select { 212 case <-c: 213 fmt.Println("Caught signal, canceling...") 214 cancel() 215 case <-ctx.Done(): 216 fmt.Println("Done") 217 case <-c0: 218 } 219 close(c) 220 }() 221 222 rs, err := dbt.conn.QueryContext(ctx, query, args...) 
223 if err != nil { 224 dbt.fail("query", query, err) 225 } 226 return &RowsExtended{ 227 rows: rs, 228 closeChan: &c0, 229 } 230 } 231 232 func (dbt *DBTest) query(query string, args ...any) (*sql.Rows, error) { 233 return dbt.conn.QueryContext(context.Background(), query, args...) 234 } 235 236 func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...interface{}) { 237 rows := dbt.mustQuery(query, args...) 238 defer rows.Close() 239 cnt := 0 240 for rows.Next() { 241 cnt++ 242 } 243 if cnt != expected { 244 dbt.Fatalf("expected %v, got %v", expected, cnt) 245 } 246 } 247 248 func (dbt *DBTest) prepare(query string) (*sql.Stmt, error) { 249 return dbt.conn.PrepareContext(context.Background(), query) 250 } 251 252 func (dbt *DBTest) fail(method, query string, err error) { 253 if len(query) > 300 { 254 query = "[query too large to print]" 255 } 256 dbt.Fatalf("error on %s [%s]: %s", method, query, err.Error()) 257 } 258 259 func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { 260 return dbt.mustExecContext(context.Background(), query, args...) 261 } 262 263 func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result) { 264 res, err := dbt.conn.ExecContext(ctx, query, args...) 265 if err != nil { 266 dbt.fail("exec context", query, err) 267 } 268 return res 269 } 270 271 func (dbt *DBTest) exec(query string, args ...any) (sql.Result, error) { 272 return dbt.conn.ExecContext(context.Background(), query, args...) 273 } 274 275 func (dbt *DBTest) mustDecimalSize(ct *sql.ColumnType) (pr int64, sc int64) { 276 var ok bool 277 pr, sc, ok = ct.DecimalSize() 278 if !ok { 279 dbt.Fatalf("failed to get decimal size. %v", ct) 280 } 281 return pr, sc 282 } 283 284 func (dbt *DBTest) mustFailDecimalSize(ct *sql.ColumnType) { 285 var ok bool 286 if _, _, ok = ct.DecimalSize(); ok { 287 dbt.Fatalf("should not return decimal size. 
%v", ct) 288 } 289 } 290 291 func (dbt *DBTest) mustLength(ct *sql.ColumnType) (cLen int64) { 292 var ok bool 293 cLen, ok = ct.Length() 294 if !ok { 295 dbt.Fatalf("failed to get length. %v", ct) 296 } 297 return cLen 298 } 299 300 func (dbt *DBTest) mustFailLength(ct *sql.ColumnType) { 301 var ok bool 302 if _, ok = ct.Length(); ok { 303 dbt.Fatalf("should not return length. %v", ct) 304 } 305 } 306 307 func (dbt *DBTest) mustNullable(ct *sql.ColumnType) (canNull bool) { 308 var ok bool 309 canNull, ok = ct.Nullable() 310 if !ok { 311 dbt.Fatalf("failed to get length. %v", ct) 312 } 313 return canNull 314 } 315 316 func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) { 317 stmt, err := dbt.conn.PrepareContext(context.Background(), query) 318 if err != nil { 319 dbt.fail("prepare", query, err) 320 } 321 return stmt 322 } 323 324 type SCTest struct { 325 *testing.T 326 sc *snowflakeConn 327 } 328 329 func (sct *SCTest) fail(method, query string, err error) { 330 if len(query) > 300 { 331 query = "[query too large to print]" 332 } 333 sct.Fatalf("error on %s [%s]: %s", method, query, err.Error()) 334 } 335 336 func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result { 337 result, err := sct.sc.Exec(query, args) 338 if err != nil { 339 sct.fail("exec", query, err) 340 } 341 return result 342 } 343 func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows { 344 rows, err := sct.sc.Query(query, args) 345 if err != nil { 346 sct.fail("query", query, err) 347 } 348 return rows 349 } 350 351 func (sct *SCTest) mustQueryContext(ctx context.Context, query string, args []driver.NamedValue) driver.Rows { 352 rows, err := sct.sc.QueryContext(ctx, query, args) 353 if err != nil { 354 sct.fail("QueryContext", query, err) 355 } 356 return rows 357 } 358 359 func (sct *SCTest) mustExecContext(ctx context.Context, query string, args []driver.NamedValue) driver.Result { 360 result, err := sct.sc.ExecContext(ctx, query, args) 361 if err 
!= nil { 362 sct.fail("ExecContext", query, err) 363 } 364 return result 365 } 366 367 func runDBTest(t *testing.T, test func(dbt *DBTest)) { 368 conn := openConn(t) 369 defer conn.Close() 370 dbt := &DBTest{t, conn} 371 372 test(dbt) 373 } 374 375 func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) { 376 config, err := ParseDSN(dsn) 377 if err != nil { 378 t.Error(err) 379 } 380 sc, err := buildSnowflakeConn(context.Background(), *config) 381 if err != nil { 382 t.Fatal(err) 383 } 384 defer sc.Close() 385 if err = authenticateWithConfig(sc); err != nil { 386 t.Fatal(err) 387 } 388 389 sct := &SCTest{t, sc} 390 391 test(sct) 392 } 393 394 func runningOnAWS() bool { 395 return os.Getenv("CLOUD_PROVIDER") == "AWS" 396 } 397 398 func runningOnGCP() bool { 399 return os.Getenv("CLOUD_PROVIDER") == "GCP" 400 } 401 402 func TestBogusUserPasswordParameters(t *testing.T) { 403 invalidDNS := fmt.Sprintf("%s:%s@%s", "bogus", pass, host) 404 invalidUserPassErrorTests(invalidDNS, t) 405 invalidDNS = fmt.Sprintf("%s:%s@%s", username, "INVALID_PASSWORD", host) 406 invalidUserPassErrorTests(invalidDNS, t) 407 } 408 409 func invalidUserPassErrorTests(invalidDNS string, t *testing.T) { 410 parameters := url.Values{} 411 if protocol != "" { 412 parameters.Add("protocol", protocol) 413 } 414 if account != "" { 415 parameters.Add("account", account) 416 } 417 invalidDNS += "?" + parameters.Encode() 418 db, err := sql.Open("snowflake", invalidDNS) 419 if err != nil { 420 t.Fatalf("error creating a connection object: %s", err.Error()) 421 } 422 // actual connection won't happen until run a query 423 defer db.Close() 424 if _, err = db.Exec("SELECT 1"); err == nil { 425 t.Fatal("should cause an error.") 426 } 427 if driverErr, ok := err.(*SnowflakeError); ok { 428 if driverErr.Number != 390100 { 429 t.Fatalf("wrong error code: %v", driverErr) 430 } 431 if !strings.Contains(driverErr.Error(), "390100") { 432 t.Fatalf("error message should included the error code. 
got: %v", driverErr.Error()) 433 } 434 } else { 435 t.Fatalf("wrong error code: %v", err) 436 } 437 } 438 439 func TestBogusHostNameParameters(t *testing.T) { 440 invalidDNS := fmt.Sprintf("%s:%s@%s", username, pass, "INVALID_HOST:1234") 441 invalidHostErrorTests(invalidDNS, []string{"no such host", "verify account name is correct", "HTTP Status: 403", "Temporary failure in name resolution", "server misbehaving"}, t) 442 invalidDNS = fmt.Sprintf("%s:%s@%s", username, pass, "INVALID_HOST") 443 invalidHostErrorTests(invalidDNS, []string{"read: connection reset by peer", "EOF", "verify account name is correct", "HTTP Status: 403", "Temporary failure in name resolution", "server misbehaving"}, t) 444 } 445 446 func invalidHostErrorTests(invalidDNS string, mstr []string, t *testing.T) { 447 parameters := url.Values{} 448 if protocol != "" { 449 parameters.Add("protocol", protocol) 450 } 451 if account != "" { 452 parameters.Add("account", account) 453 } 454 parameters.Add("loginTimeout", "10") 455 invalidDNS += "?" 
+ parameters.Encode() 456 db, err := sql.Open("snowflake", invalidDNS) 457 if err != nil { 458 t.Fatalf("error creating a connection object: %s", err.Error()) 459 } 460 // actual connection won't happen until run a query 461 defer db.Close() 462 if _, err = db.Exec("SELECT 1"); err == nil { 463 t.Fatal("should cause an error.") 464 } 465 found := false 466 for _, m := range mstr { 467 if strings.Contains(err.Error(), m) { 468 found = true 469 } 470 } 471 if !found { 472 t.Fatalf("wrong error: %v", err) 473 } 474 } 475 476 func TestCommentOnlyQuery(t *testing.T) { 477 runDBTest(t, func(dbt *DBTest) { 478 query := "--" 479 // just a comment, no query 480 rows, err := dbt.query(query) 481 if err == nil { 482 rows.Close() 483 dbt.fail("query", query, err) 484 } 485 if driverErr, ok := err.(*SnowflakeError); ok { 486 if driverErr.Number != 900 { // syntax error 487 dbt.fail("query", query, err) 488 } 489 } 490 }) 491 } 492 493 func TestEmptyQuery(t *testing.T) { 494 runDBTest(t, func(dbt *DBTest) { 495 query := "select 1 from dual where 1=0" 496 // just a comment, no query 497 rows := dbt.conn.QueryRowContext(context.Background(), query) 498 var v1 any 499 if err := rows.Scan(&v1); err != sql.ErrNoRows { 500 dbt.Errorf("should fail. err: %v", err) 501 } 502 rows = dbt.conn.QueryRowContext(context.Background(), query) 503 if err := rows.Scan(&v1); err != sql.ErrNoRows { 504 dbt.Errorf("should fail. err: %v", err) 505 } 506 }) 507 } 508 509 func TestEmptyQueryWithRequestID(t *testing.T) { 510 runDBTest(t, func(dbt *DBTest) { 511 query := "select 1" 512 ctx := WithRequestID(context.Background(), NewUUID()) 513 rows := dbt.conn.QueryRowContext(ctx, query) 514 var v1 interface{} 515 if err := rows.Scan(&v1); err != nil { 516 dbt.Errorf("should not have failed with valid request id. 
err: %v", err) 517 } 518 }) 519 } 520 521 func TestCRUD(t *testing.T) { 522 runDBTest(t, func(dbt *DBTest) { 523 // Create Table 524 dbt.mustExec("CREATE OR REPLACE TABLE test (value BOOLEAN)") 525 526 // Test for unexpected Data 527 var out bool 528 rows := dbt.mustQuery("SELECT * FROM test") 529 defer rows.Close() 530 if rows.Next() { 531 dbt.Error("unexpected Data in empty table") 532 } 533 534 // Create Data 535 res := dbt.mustExec("INSERT INTO test VALUES (true)") 536 count, err := res.RowsAffected() 537 if err != nil { 538 dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) 539 } 540 if count != 1 { 541 dbt.Fatalf("expected 1 affected row, got %d", count) 542 } 543 544 id, err := res.LastInsertId() 545 if err != nil { 546 dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) 547 } 548 if id != -1 { 549 dbt.Fatalf( 550 "expected InsertId -1, got %d. Snowflake doesn't support last insert ID", id) 551 } 552 553 // Read 554 rows = dbt.mustQuery("SELECT value FROM test") 555 defer rows.Close() 556 if rows.Next() { 557 rows.Scan(&out) 558 if !out { 559 dbt.Errorf("%t should be true", out) 560 } 561 if rows.Next() { 562 dbt.Error("unexpected Data") 563 } 564 } else { 565 dbt.Error("no Data") 566 } 567 568 // Update 569 res = dbt.mustExec("UPDATE test SET value = ? 
WHERE value = ?", false, true) 570 count, err = res.RowsAffected() 571 if err != nil { 572 dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) 573 } 574 if count != 1 { 575 dbt.Fatalf("expected 1 affected row, got %d", count) 576 } 577 578 // Check Update 579 rows = dbt.mustQuery("SELECT value FROM test") 580 defer rows.Close() 581 if rows.Next() { 582 rows.Scan(&out) 583 if out { 584 dbt.Errorf("%t should be true", out) 585 } 586 if rows.Next() { 587 dbt.Error("unexpected Data") 588 } 589 } else { 590 dbt.Error("no Data") 591 } 592 593 // Delete 594 res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) 595 count, err = res.RowsAffected() 596 if err != nil { 597 dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) 598 } 599 if count != 1 { 600 dbt.Fatalf("expected 1 affected row, got %d", count) 601 } 602 603 // Check for unexpected rows 604 res = dbt.mustExec("DELETE FROM test") 605 count, err = res.RowsAffected() 606 if err != nil { 607 dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) 608 } 609 if count != 0 { 610 dbt.Fatalf("expected 0 affected row, got %d", count) 611 } 612 }) 613 } 614 615 func TestInt(t *testing.T) { 616 testInt(t, false) 617 } 618 619 func testInt(t *testing.T, json bool) { 620 runDBTest(t, func(dbt *DBTest) { 621 types := []string{"INT", "INTEGER"} 622 in := int64(42) 623 var out int64 624 var rows *RowsExtended 625 626 // SIGNED 627 for _, v := range types { 628 t.Run(v, func(t *testing.T) { 629 if json { 630 dbt.mustExec(forceJSON) 631 } 632 dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") 633 dbt.mustExec("INSERT INTO test VALUES (?)", in) 634 rows = dbt.mustQuery("SELECT value FROM test") 635 defer rows.Close() 636 if rows.Next() { 637 rows.Scan(&out) 638 if in != out { 639 dbt.Errorf("%s: %d != %d", v, in, out) 640 } 641 } else { 642 dbt.Errorf("%s: no data", v) 643 } 644 645 }) 646 } 647 dbt.mustExec("DROP TABLE IF EXISTS test") 648 }) 649 } 650 651 func TestFloat32(t 
*testing.T) { 652 testFloat32(t, false) 653 } 654 655 func testFloat32(t *testing.T, json bool) { 656 runDBTest(t, func(dbt *DBTest) { 657 types := [2]string{"FLOAT", "DOUBLE"} 658 in := float32(42.23) 659 var out float32 660 var rows *RowsExtended 661 for _, v := range types { 662 t.Run(v, func(t *testing.T) { 663 if json { 664 dbt.mustExec(forceJSON) 665 } 666 dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") 667 dbt.mustExec("INSERT INTO test VALUES (?)", in) 668 rows = dbt.mustQuery("SELECT value FROM test") 669 defer rows.Close() 670 if rows.Next() { 671 err := rows.Scan(&out) 672 if err != nil { 673 dbt.Errorf("failed to scan data: %v", err) 674 } 675 if in != out { 676 dbt.Errorf("%s: %g != %g", v, in, out) 677 } 678 } else { 679 dbt.Errorf("%s: no data", v) 680 } 681 }) 682 } 683 dbt.mustExec("DROP TABLE IF EXISTS test") 684 }) 685 } 686 687 func TestFloat64(t *testing.T) { 688 testFloat64(t, false) 689 } 690 691 func testFloat64(t *testing.T, json bool) { 692 runDBTest(t, func(dbt *DBTest) { 693 types := [2]string{"FLOAT", "DOUBLE"} 694 expected := 42.23 695 var out float64 696 var rows *RowsExtended 697 for _, v := range types { 698 t.Run(v, func(t *testing.T) { 699 if json { 700 dbt.mustExec(forceJSON) 701 } 702 dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") 703 dbt.mustExec("INSERT INTO test VALUES (42.23)") 704 rows = dbt.mustQuery("SELECT value FROM test") 705 defer rows.Close() 706 if rows.Next() { 707 rows.Scan(&out) 708 if expected != out { 709 dbt.Errorf("%s: %g != %g", v, expected, out) 710 } 711 } else { 712 dbt.Errorf("%s: no data", v) 713 } 714 }) 715 } 716 dbt.mustExec("DROP TABLE IF EXISTS test") 717 }) 718 } 719 720 func TestString(t *testing.T) { 721 testString(t, false) 722 } 723 724 func testString(t *testing.T, json bool) { 725 runDBTest(t, func(dbt *DBTest) { 726 if json { 727 dbt.mustExec(forceJSON) 728 } 729 types := []string{"CHAR(255)", "VARCHAR(255)", "TEXT", "STRING"} 730 in := "κόσμε üöäßñóùéàâÿœ'îë 
Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" 731 var out string 732 var rows *RowsExtended 733 734 for _, v := range types { 735 t.Run(v, func(t *testing.T) { 736 dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")") 737 dbt.mustExec("INSERT INTO test VALUES (?)", in) 738 739 rows = dbt.mustQuery("SELECT value FROM test") 740 defer rows.Close() 741 if rows.Next() { 742 rows.Scan(&out) 743 if in != out { 744 dbt.Errorf("%s: %s != %s", v, in, out) 745 } 746 } else { 747 dbt.Errorf("%s: no data", v) 748 } 749 }) 750 } 751 dbt.mustExec("DROP TABLE IF EXISTS test") 752 753 // BLOB (Snowflake doesn't support BLOB type but STRING covers large text data) 754 dbt.mustExec("CREATE OR REPLACE TABLE test (id int, value STRING)") 755 756 id := 2 757 in = `Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam 758 nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam 759 erat, sed diam voluptua. At vero eos et accusam et justo duo 760 dolores et ea rebum. Stet clita kasd gubergren, no sea takimata 761 sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, 762 consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt 763 ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero 764 eos et accusam et justo duo dolores et ea rebum. 
Stet clita kasd 765 gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.` 766 dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) 767 768 if err := dbt.conn.QueryRowContext(context.Background(), "SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil { 769 dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) 770 } else if out != in { 771 dbt.Errorf("BLOB: %s != %s", in, out) 772 } 773 }) 774 } 775 776 type tcDateTimeTimestamp struct { 777 dbtype string 778 tlayout string 779 tests []timeTest 780 } 781 782 type timeTest struct { 783 s string // source date time string 784 t time.Time // expected fetched data 785 } 786 787 func (tt timeTest) genQuery() string { 788 return "SELECT '%s'::%s" 789 } 790 791 func (tt timeTest) run(t *testing.T, dbt *DBTest, dbtype, tlayout string) { 792 var rows *RowsExtended 793 query := fmt.Sprintf(tt.genQuery(), tt.s, dbtype) 794 rows = dbt.mustQuery(query) 795 defer rows.Close() 796 var err error 797 if !rows.Next() { 798 err = rows.Err() 799 if err == nil { 800 err = fmt.Errorf("no data") 801 } 802 dbt.Errorf("%s: %s", dbtype, err) 803 return 804 } 805 806 var dst interface{} 807 if err = rows.Scan(&dst); err != nil { 808 dbt.Errorf("%s: %s", dbtype, err) 809 return 810 } 811 switch val := dst.(type) { 812 case []uint8: 813 str := string(val) 814 if str == tt.s { 815 return 816 } 817 dbt.Errorf("%s to string: expected %q, got %q", 818 dbtype, 819 tt.s, 820 str, 821 ) 822 case time.Time: 823 if val.UnixNano() == tt.t.UnixNano() { 824 return 825 } 826 t.Logf("source:%v, expected: %v, got:%v", tt.s, tt.t, val) 827 dbt.Errorf("%s to string: expected %q, got %q", 828 dbtype, 829 tt.s, 830 val.Format(tlayout), 831 ) 832 default: 833 dbt.Errorf("%s: unhandled type %T (is '%v')", 834 dbtype, val, val, 835 ) 836 } 837 } 838 839 func TestSimpleDateTimeTimestampFetch(t *testing.T) { 840 testSimpleDateTimeTimestampFetch(t, false) 841 } 842 843 func testSimpleDateTimeTimestampFetch(t *testing.T, json bool) { 844 var scan 
= func(rows *RowsExtended, cd interface{}, ct interface{}, cts interface{}) { 845 if err := rows.Scan(cd, ct, cts); err != nil { 846 t.Fatal(err) 847 } 848 } 849 var fetchTypes = []func(*RowsExtended){ 850 func(rows *RowsExtended) { 851 var cd, ct, cts time.Time 852 scan(rows, &cd, &ct, &cts) 853 }, 854 func(rows *RowsExtended) { 855 var cd, ct, cts time.Time 856 scan(rows, &cd, &ct, &cts) 857 }, 858 } 859 runDBTest(t, func(dbt *DBTest) { 860 if json { 861 dbt.mustExec(forceJSON) 862 } 863 for _, f := range fetchTypes { 864 rows := dbt.mustQuery("SELECT CURRENT_DATE(), CURRENT_TIME(), CURRENT_TIMESTAMP()") 865 defer rows.Close() 866 if rows.Next() { 867 f(rows) 868 } else { 869 t.Fatal("no results") 870 } 871 } 872 }) 873 } 874 875 func TestDateTime(t *testing.T) { 876 testDateTime(t, false) 877 } 878 879 func testDateTime(t *testing.T, json bool) { 880 afterTime := func(t time.Time, d string) time.Time { 881 dur, err := time.ParseDuration(d) 882 if err != nil { 883 panic(err) 884 } 885 return t.Add(dur) 886 } 887 t0 := time.Time{} 888 tstr0 := "0000-00-00 00:00:00.000000000" 889 testcases := []tcDateTimeTimestamp{ 890 {"DATE", format[:10], []timeTest{ 891 {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)}, 892 {t: time.Date(2, 8, 2, 0, 0, 0, 0, time.UTC), s: "0002-08-02"}, 893 }}, 894 {"TIME", format[11:19], []timeTest{ 895 {t: afterTime(t0, "12345s")}, 896 {t: t0, s: tstr0[11:19]}, 897 }}, 898 {"TIME(0)", format[11:19], []timeTest{ 899 {t: afterTime(t0, "12345s")}, 900 {t: t0, s: tstr0[11:19]}, 901 }}, 902 {"TIME(1)", format[11:21], []timeTest{ 903 {t: afterTime(t0, "12345600ms")}, 904 {t: t0, s: tstr0[11:21]}, 905 }}, 906 {"TIME(6)", format[11:], []timeTest{ 907 {t: t0, s: tstr0[11:]}, 908 }}, 909 {"DATETIME", format[:19], []timeTest{ 910 {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, 911 }}, 912 {"DATETIME(0)", format[:21], []timeTest{ 913 {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, 914 }}, 915 {"DATETIME(1)", format[:21], []timeTest{ 916 {t: 
time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)}, 917 }}, 918 {"DATETIME(6)", format, []timeTest{ 919 {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)}, 920 }}, 921 {"DATETIME(9)", format, []timeTest{ 922 {t: time.Date(2011, 11, 20, 21, 27, 37, 123456789, time.UTC)}, 923 }}, 924 } 925 runDBTest(t, func(dbt *DBTest) { 926 if json { 927 dbt.mustExec(forceJSON) 928 } 929 for _, setups := range testcases { 930 t.Run(setups.dbtype, func(t *testing.T) { 931 for _, setup := range setups.tests { 932 if setup.s == "" { 933 // fill time string wherever Go can reliable produce it 934 setup.s = setup.t.Format(setups.tlayout) 935 } 936 setup.run(t, dbt, setups.dbtype, setups.tlayout) 937 } 938 }) 939 } 940 }) 941 } 942 943 func TestTimestampLTZ(t *testing.T) { 944 testTimestampLTZ(t, false) 945 } 946 947 func testTimestampLTZ(t *testing.T, json bool) { 948 // Set session time zone in Los Angeles, same as machine 949 createDSN(PSTLocation) 950 location, err := time.LoadLocation(PSTLocation) 951 if err != nil { 952 t.Error(err) 953 } 954 testcases := []tcDateTimeTimestamp{ 955 { 956 dbtype: "TIMESTAMP_LTZ(9)", 957 tlayout: format, 958 tests: []timeTest{ 959 { 960 s: "2016-12-30 05:02:03", 961 t: time.Date(2016, 12, 30, 5, 2, 3, 0, location), 962 }, 963 { 964 s: "2016-12-30 05:02:03 -00:00", 965 t: time.Date(2016, 12, 30, 5, 2, 3, 0, time.UTC), 966 }, 967 { 968 s: "2017-05-12 00:51:42", 969 t: time.Date(2017, 5, 12, 0, 51, 42, 0, location), 970 }, 971 { 972 s: "2017-03-12 01:00:00", 973 t: time.Date(2017, 3, 12, 1, 0, 0, 0, location), 974 }, 975 { 976 s: "2017-03-13 04:00:00", 977 t: time.Date(2017, 3, 13, 4, 0, 0, 0, location), 978 }, 979 { 980 s: "2017-03-13 04:00:00.123456789", 981 t: time.Date(2017, 3, 13, 4, 0, 0, 123456789, location), 982 }, 983 }, 984 }, 985 { 986 dbtype: "TIMESTAMP_LTZ(8)", 987 tlayout: format, 988 tests: []timeTest{ 989 { 990 s: "2017-03-13 04:00:00.123456789", 991 t: time.Date(2017, 3, 13, 4, 0, 0, 123456780, location), 992 }, 993 }, 
994 }, 995 } 996 runDBTest(t, func(dbt *DBTest) { 997 if json { 998 dbt.mustExec(forceJSON) 999 } 1000 for _, setups := range testcases { 1001 t.Run(setups.dbtype, func(t *testing.T) { 1002 for _, setup := range setups.tests { 1003 if setup.s == "" { 1004 // fill time string wherever Go can reliable produce it 1005 setup.s = setup.t.Format(setups.tlayout) 1006 } 1007 setup.run(t, dbt, setups.dbtype, setups.tlayout) 1008 } 1009 }) 1010 } 1011 }) 1012 // Revert timezone to UTC, which is default for the test suit 1013 createDSN("UTC") 1014 } 1015 1016 func TestTimestampTZ(t *testing.T) { 1017 testTimestampTZ(t, false) 1018 } 1019 1020 func testTimestampTZ(t *testing.T, json bool) { 1021 sflo := func(offsets string) (loc *time.Location) { 1022 r, err := LocationWithOffsetString(offsets) 1023 if err != nil { 1024 return time.UTC 1025 } 1026 return r 1027 } 1028 testcases := []tcDateTimeTimestamp{ 1029 { 1030 dbtype: "TIMESTAMP_TZ(9)", 1031 tlayout: format, 1032 tests: []timeTest{ 1033 { 1034 s: "2016-12-30 05:02:03 +07:00", 1035 t: time.Date(2016, 12, 30, 5, 2, 3, 0, 1036 sflo("+0700")), 1037 }, 1038 { 1039 s: "2017-05-23 03:56:41 -09:00", 1040 t: time.Date(2017, 5, 23, 3, 56, 41, 0, 1041 sflo("-0900")), 1042 }, 1043 }, 1044 }, 1045 } 1046 runDBTest(t, func(dbt *DBTest) { 1047 if json { 1048 dbt.mustExec(forceJSON) 1049 } 1050 for _, setups := range testcases { 1051 t.Run(setups.dbtype, func(t *testing.T) { 1052 for _, setup := range setups.tests { 1053 if setup.s == "" { 1054 // fill time string wherever Go can reliable produce it 1055 setup.s = setup.t.Format(setups.tlayout) 1056 } 1057 setup.run(t, dbt, setups.dbtype, setups.tlayout) 1058 } 1059 }) 1060 } 1061 }) 1062 } 1063 1064 func TestNULL(t *testing.T) { 1065 testNULL(t, false) 1066 } 1067 1068 func testNULL(t *testing.T, json bool) { 1069 runDBTest(t, func(dbt *DBTest) { 1070 if json { 1071 dbt.mustExec(forceJSON) 1072 } 1073 nullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT NULL") 1074 if 
err != nil { 1075 dbt.Fatal(err) 1076 } 1077 defer nullStmt.Close() 1078 1079 nonNullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT 1") 1080 if err != nil { 1081 dbt.Fatal(err) 1082 } 1083 defer nonNullStmt.Close() 1084 1085 // NullBool 1086 var nb sql.NullBool 1087 // Invalid 1088 if err = nullStmt.QueryRow().Scan(&nb); err != nil { 1089 dbt.Fatal(err) 1090 } 1091 if nb.Valid { 1092 dbt.Error("valid NullBool which should be invalid") 1093 } 1094 // Valid 1095 if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { 1096 dbt.Fatal(err) 1097 } 1098 if !nb.Valid { 1099 dbt.Error("invalid NullBool which should be valid") 1100 } else if !nb.Bool { 1101 dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) 1102 } 1103 1104 // NullFloat64 1105 var nf sql.NullFloat64 1106 // Invalid 1107 if err = nullStmt.QueryRow().Scan(&nf); err != nil { 1108 dbt.Fatal(err) 1109 } 1110 if nf.Valid { 1111 dbt.Error("valid NullFloat64 which should be invalid") 1112 } 1113 // Valid 1114 if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { 1115 dbt.Fatal(err) 1116 } 1117 if !nf.Valid { 1118 dbt.Error("invalid NullFloat64 which should be valid") 1119 } else if nf.Float64 != float64(1) { 1120 dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) 1121 } 1122 1123 // NullInt64 1124 var ni sql.NullInt64 1125 // Invalid 1126 if err = nullStmt.QueryRow().Scan(&ni); err != nil { 1127 dbt.Fatal(err) 1128 } 1129 if ni.Valid { 1130 dbt.Error("valid NullInt64 which should be invalid") 1131 } 1132 // Valid 1133 if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { 1134 dbt.Fatal(err) 1135 } 1136 if !ni.Valid { 1137 dbt.Error("invalid NullInt64 which should be valid") 1138 } else if ni.Int64 != int64(1) { 1139 dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64) 1140 } 1141 1142 // NullString 1143 var ns sql.NullString 1144 // Invalid 1145 if err = nullStmt.QueryRow().Scan(&ns); err != nil { 1146 dbt.Fatal(err) 1147 } 1148 if 
ns.Valid { 1149 dbt.Error("valid NullString which should be invalid") 1150 } 1151 // Valid 1152 if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { 1153 dbt.Fatal(err) 1154 } 1155 if !ns.Valid { 1156 dbt.Error("invalid NullString which should be valid") 1157 } else if ns.String != `1` { 1158 dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)") 1159 } 1160 1161 // nil-bytes 1162 var b []byte 1163 // Read nil 1164 if err = nullStmt.QueryRow().Scan(&b); err != nil { 1165 dbt.Fatal(err) 1166 } 1167 if b != nil { 1168 dbt.Error("non-nil []byte which should be nil") 1169 } 1170 // Read non-nil 1171 if err = nonNullStmt.QueryRow().Scan(&b); err != nil { 1172 dbt.Fatal(err) 1173 } 1174 if b == nil { 1175 dbt.Error("nil []byte which should be non-nil") 1176 } 1177 // Insert nil 1178 b = nil 1179 success := false 1180 if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ? IS NULL", b).Scan(&success); err != nil { 1181 dbt.Fatal(err) 1182 } 1183 if !success { 1184 dbt.Error("inserting []byte(nil) as NULL failed") 1185 t.Fatal("stopping") 1186 } 1187 // Check input==output with input==nil 1188 b = nil 1189 if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil { 1190 dbt.Fatal(err) 1191 } 1192 if b != nil { 1193 dbt.Error("non-nil echo from nil input") 1194 } 1195 // Check input==output with input!=nil 1196 b = []byte("") 1197 if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil { 1198 dbt.Fatal(err) 1199 } 1200 if b == nil { 1201 dbt.Error("nil echo from non-nil input") 1202 } 1203 1204 // Insert NULL 1205 dbt.mustExec("CREATE OR REPLACE TABLE test (dummmy1 int, value int, dummy2 int)") 1206 dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) 1207 1208 var out interface{} 1209 rows := dbt.mustQuery("SELECT * FROM test") 1210 defer rows.Close() 1211 if rows.Next() { 1212 rows.Scan(&out) 1213 if out != nil { 1214 dbt.Errorf("%v != nil", out) 1215 } 1216 } 
else {
			dbt.Error("no data")
		}
	})
}

// TestVariant verifies that a VARIANT produced by PARSE_JSON can be scanned
// into a Go string.
func TestVariant(t *testing.T) {
	testVariant(t, false)
}

// testVariant runs the VARIANT scan check. When json is true the forceJSON
// statement is executed on the connection before querying.
func testVariant(t *testing.T, json bool) {
	runDBTest(t, func(dbt *DBTest) {
		if json {
			dbt.mustExec(forceJSON)
		}
		rows := dbt.mustQuery(`select parse_json('[{"id":1, "name":"test1"},{"id":2, "name":"test2"}]')`)
		defer rows.Close()
		if !rows.Next() {
			t.Fatal("no rows")
		}
		var v string
		if err := rows.Scan(&v); err != nil {
			t.Fatal(err)
		}
	})
}

// TestArray verifies that an ARRAY produced by AS_ARRAY can be scanned into a
// Go string.
func TestArray(t *testing.T) {
	testArray(t, false)
}

// testArray runs the ARRAY scan check. When json is true the forceJSON
// statement is executed on the connection before querying.
func testArray(t *testing.T, json bool) {
	runDBTest(t, func(dbt *DBTest) {
		if json {
			dbt.mustExec(forceJSON)
		}
		rows := dbt.mustQuery(`select as_array(parse_json('[{"id":1, "name":"test1"},{"id":2, "name":"test2"}]'))`)
		defer rows.Close()
		if !rows.Next() {
			t.Fatal("no rows")
		}
		var v string
		if err := rows.Scan(&v); err != nil {
			t.Fatal(err)
		}
	})
}

// TestLargeSetResult fetches a 100000-row generated result set with the
// custom JSON decoder disabled.
func TestLargeSetResult(t *testing.T) {
	// CustomJSONDecoderEnabled is package-level state; restore the previous
	// value afterwards so this test does not leak its setting into tests that
	// run after it.
	defer func(prev bool) { CustomJSONDecoderEnabled = prev }(CustomJSONDecoderEnabled)
	CustomJSONDecoderEnabled = false
	testLargeSetResult(t, 100000, false)
}

// testLargeSetResult streams numrows rows produced by selectRandomGenerator
// and checks that exactly numrows rows were received. When json is true the
// forceJSON statement is executed first.
func testLargeSetResult(t *testing.T, numrows int, json bool) {
	runDBTest(t, func(dbt *DBTest) {
		if json {
			dbt.mustExec(forceJSON)
		}
		rows := dbt.mustQuery(fmt.Sprintf(selectRandomGenerator, numrows))
		defer rows.Close()
		cnt := 0
		var idx int
		var v string
		for rows.Next() {
			if err := rows.Scan(&idx, &v); err != nil {
				t.Fatal(err)
			}
			cnt++
		}
		logger.Infof("NextResultSet: %v", rows.NextResultSet())

		if cnt != numrows {
			dbt.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
		}
	})
}

// TestPingpongQuery runs a DISTINCT query over a GENERATOR with
// TIMELIMIT=>60 and expects exactly one row back.
func TestPingpongQuery(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		numrows := 1
		rows := dbt.mustQuery("SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 60))")
		defer rows.Close()
		cnt := 0
		for rows.Next() {
			cnt++
		}
		if cnt != numrows {
			dbt.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
		}
	})
}

// TestDML checks transactional visibility: rows inserted in a rolled-back
// transaction must not be visible afterwards, rows in a committed one must be.
func TestDML(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("CREATE OR REPLACE TABLE test(c1 int, c2 string)")
		if err := insertData(dbt, false); err != nil {
			dbt.Fatalf("failed to insert data: %v", err)
		}
		results, err := queryTest(dbt)
		if err != nil {
			dbt.Fatalf("failed to query test table: %v", err)
		}
		if len(*results) != 0 {
			dbt.Fatalf("number of returned data didn't match. expected 0, got: %v", len(*results))
		}
		if err = insertData(dbt, true); err != nil {
			dbt.Fatalf("failed to insert data: %v", err)
		}
		results, err = queryTest(dbt)
		if err != nil {
			dbt.Fatalf("failed to query test table: %v", err)
		}
		if len(*results) != 2 {
			dbt.Fatalf("number of returned data didn't match. expected 2, got: %v", len(*results))
		}
	})
}

// insertData inserts two rows into the test table inside a transaction and
// checks that both are visible within that transaction. It then commits when
// commit is true, or rolls back otherwise. Setup failures abort the test via
// dbt.Fatalf; the Commit/Rollback error is returned to the caller.
func insertData(dbt *DBTest, commit bool) error {
	tx, err := dbt.conn.BeginTx(context.Background(), nil)
	if err != nil {
		dbt.Fatalf("failed to begin transaction: %v", err)
	}
	res, err := tx.Exec("INSERT INTO test VALUES(1, 'test1'), (2, 'test2')")
	if err != nil {
		dbt.Fatalf("failed to insert value into test: %v", err)
	}
	n, err := res.RowsAffected()
	if err != nil {
		dbt.Fatalf("failed to rows affected: %v", err)
	}
	if n != 2 {
		dbt.Fatalf("failed to insert value into test. expected: 2, got: %v", n)
	}
	results, err := queryTestTx(tx)
	if err != nil {
		dbt.Fatalf("failed to query test table: %v", err)
	}
	if len(*results) != 2 {
		dbt.Fatalf("number of returned data didn't match. expected 2, got: %v", len(*results))
	}
	if commit {
		return tx.Commit()
	}
	return tx.Rollback()
}

// queryTestTx reads every (c1, c2) row of the test table through tx and
// returns them keyed by c1. Errors raised while iterating are returned too.
func queryTestTx(tx *sql.Tx) (*map[int]string, error) {
	rows, err := tx.Query("SELECT c1, c2 FROM test")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var c1 int
	var c2 string
	results := make(map[int]string, 2)
	for rows.Next() {
		if err = rows.Scan(&c1, &c2); err != nil {
			return nil, err
		}
		results[c1] = c2
	}
	// rows.Next returns false both at the end and on failure; tell them apart.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return &results, nil
}

// queryTest reads every (c1, c2) row of the test table outside any explicit
// transaction and returns them keyed by c1. Errors raised while iterating are
// returned too.
func queryTest(dbt *DBTest) (*map[int]string, error) {
	rows, err := dbt.query("SELECT c1, c2 FROM test")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var c1 int
	var c2 string
	results := make(map[int]string, 2)
	for rows.Next() {
		if err = rows.Scan(&c1, &c2); err != nil {
			return nil, err
		}
		results[c1] = c2
	}
	// rows.Next returns false both at the end and on failure; tell them apart.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return &results, nil
}

// TestCancelQuery checks that a long-running query is aborted with a
// "context deadline exceeded" error when its context times out.
func TestCancelQuery(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()

		rows, err := dbt.conn.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))")
		if err == nil {
			rows.Close()
			dbt.Fatal("No timeout error returned")
		}
		if err.Error() != "context deadline exceeded" {
			dbt.Fatalf("Timeout error mismatch: expect %v, receive %v", context.DeadlineExceeded, err.Error())
		}
	})
}

// TestPing checks that pinging succeeds on an open handle and fails after the
// handle has been closed.
func TestPing(t *testing.T) {
	db := openConn(t)
	if err := db.PingContext(context.Background()); err != nil {
		t.Fatalf("failed to ping. err: %v", err)
	}
	if err := db.PingContext(context.Background()); err != nil {
		t.Fatalf("failed to ping with context. err: %v", err)
	}
	if err := db.Close(); err != nil {
		t.Fatalf("failed to close db. err: %v", err)
	}
	if err := db.PingContext(context.Background()); err == nil {
		t.Fatal("should have failed to ping")
	}
	if err := db.PingContext(context.Background()); err == nil {
		t.Fatal("should have failed to ping with context")
	}
}

// TestDoubleDollar checks that a statement containing $$-quoted bodies is
// sent as-is, without escaping the dollar signs.
func TestDoubleDollar(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		// Named query rather than sql so the database/sql package is not shadowed.
		query := `create or replace function dateErr(I double) returns date
language javascript strict
as $$
var x = [
    0, "1400000000000",
    "2013-04-05",
    [], [1400000000000],
    "x1234",
    Number.NaN, null, undefined,
    {},
    [1400000000000,1500000000000]
];
return x[I];
$$
;`
		dbt.mustExec(query)
	})
}

// TestTimezoneSessionParameter opens a connection whose DSN sets the timezone
// to PSTLocation and checks the server reports that TIMEZONE value.
func TestTimezoneSessionParameter(t *testing.T) {
	createDSN(PSTLocation)
	// Restore the default (UTC) package-level dsn even if this test fails
	// part-way, so later tests do not inherit the PST session timezone.
	defer createDSN("UTC")
	conn := openConn(t)
	defer conn.Close()

	rows, err := conn.QueryContext(context.Background(), "SHOW PARAMETERS LIKE 'TIMEZONE'")
	if err != nil {
		// Fatal: rows is unusable below when the query failed.
		t.Fatalf("failed to run show parameters. err: %v", err)
	}
	defer rows.Close()
	if !rows.Next() {
		t.Fatal("failed to get timezone.")
	}

	p, err := ScanSnowflakeParameter(rows)
	if err != nil {
		// Fatal: p must not be dereferenced when scanning failed.
		t.Fatalf("failed to run get timezone value. err: %v", err)
	}
	if p.Value != PSTLocation {
		t.Errorf("failed to get an expected timezone. got: %v", p.Value)
	}
}

// TestLargeSetResultCancel starts a long GENERATOR query in a goroutine,
// cancels its context after one second, and expects a "context canceled"
// error.
func TestLargeSetResultCancel(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		c := make(chan error)
		ctx, cancel := context.WithCancel(context.Background())
		go func() {
			// attempt to run a 100 seconds query, but it should be canceled in 1 second
			timelimit := 100
			rows, err := dbt.conn.QueryContext(
				ctx,
				fmt.Sprintf("SELECT COUNT(*) FROM TABLE(GENERATOR(timelimit=>%v))", timelimit))
			if err != nil {
				c <- err
				return
			}
			defer rows.Close()
			c <- nil
		}()
		// cancel after 1 second
		time.Sleep(time.Second)
		cancel()
		ret := <-c
		// ret is nil when the query returned without being canceled; check
		// that before calling Error() so the test fails instead of panicking.
		if ret == nil || ret.Error() != "context canceled" {
			t.Fatalf("failed to cancel. err: %v", ret)
		}
		close(c)
	})
}

// TestValidateDatabaseParameter connects with nonexistent database, schema,
// warehouse, or role values and checks the Snowflake error number returned.
func TestValidateDatabaseParameter(t *testing.T) {
	baseDSN := fmt.Sprintf("%s:%s@%s", username, pass, host)
	testcases := []struct {
		dsn       string
		params    map[string]string
		errorCode int
	}{
		{
			dsn:       baseDSN + fmt.Sprintf("/%s/%s", "NOT_EXISTS", "NOT_EXISTS"),
			errorCode: ErrObjectNotExistOrAuthorized,
		},
		{
			dsn:       baseDSN + fmt.Sprintf("/%s/%s", dbname, "NOT_EXISTS"),
			errorCode: ErrObjectNotExistOrAuthorized,
		},
		{
			dsn: baseDSN + fmt.Sprintf("/%s/%s", dbname, schemaname),
			params: map[string]string{
				"warehouse": "NOT_EXIST",
			},
			errorCode: ErrObjectNotExistOrAuthorized,
		},
		{
			dsn: baseDSN + fmt.Sprintf("/%s/%s", dbname, schemaname),
			params: map[string]string{
				"role": "NOT_EXIST",
			},
			errorCode: ErrRoleNotExist,
		},
	}
	for idx, tc := range testcases {
		// Name each subtest by its index rather than by a DSN: DSNs embed the
		// password, and a single shared name would make every subtest
		// indistinguishable in test output.
		t.Run(fmt.Sprintf("case%d", idx), func(t *testing.T) {
			newDSN := tc.dsn
			parameters := url.Values{}
			if protocol != "" {
				parameters.Add("protocol", protocol)
			}
			if account != "" {
				parameters.Add("account", account)
			}
			for k, v := range tc.params {
				parameters.Add(k, v)
			}
			newDSN += "?" + parameters.Encode()
			db, err := sql.Open("snowflake", newDSN)
			// actual connection won't happen until run a query
			if err != nil {
				t.Fatalf("error creating a connection object: %s", err.Error())
			}
			defer db.Close()
			if _, err = db.Exec("SELECT 1"); err == nil {
				t.Fatal("should cause an error.")
			}
			if driverErr, ok := err.(*SnowflakeError); ok {
				if driverErr.Number != tc.errorCode { // not exist error
					t.Errorf("got unexpected error: %v in %v", err, idx)
				}
			}
		})
	}
}

// TestSpecifyWarehouseDatabase connects with only a database (no schema) in
// the DSN path plus account/warehouse parameters and runs SELECT 1.
func TestSpecifyWarehouseDatabase(t *testing.T) {
	// Local name avoids shadowing the package-level dsn.
	warehouseDSN := fmt.Sprintf("%s:%s@%s/%s", username, pass, host, dbname)
	parameters := url.Values{}
	parameters.Add("account", account)
	parameters.Add("warehouse", warehouse)
	// parameters.Add("role", "nopublic") TODO: create nopublic role for test
	if protocol != "" {
		parameters.Add("protocol", protocol)
	}
	db, err := sql.Open("snowflake", warehouseDSN+"?"+parameters.Encode())
	if err != nil {
		t.Fatalf("error creating a connection object: %s", err.Error())
	}
	defer db.Close()
	if _, err = db.Exec("SELECT 1"); err != nil {
		t.Fatalf("failed to execute a select 1: %v", err)
	}
}

// TestFetchNil checks that a NULL value scans into an invalid sql.NullInt64.
func TestFetchNil(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQuery("SELECT * FROM values(3,4),(null, 5) order by 2")
		defer rows.Close()
		var c1 sql.NullInt64
		var c2 sql.NullInt64

		var results []sql.NullInt64
		for rows.Next() {
			if err := rows.Scan(&c1, &c2); err != nil {
				dbt.Fatal(err)
			}
			results = append(results, c1)
		}
		// Guard the index below so a short result fails cleanly, not by panic.
		if len(results) < 2 {
			t.Fatalf("expected at least 2 rows, got %v", len(results))
		}
		if results[1].Valid {
			t.Errorf("First element of second row must be nil (NULL). %v", results)
		}
	})
}

// TestPingInvalidHost expects pinging a nonexistent account to fail with
// ErrCodeFailedToConnect.
func TestPingInvalidHost(t *testing.T) {
	config := Config{
		Account:      "NOT_EXISTS",
		User:         "BOGUS_USER",
		Password:     "barbar",
		LoginTimeout: 10 * time.Second,
	}

	testURL, err := DSN(&config)
	if err != nil {
		t.Fatalf("failed to parse config. config: %v, err: %v", config, err)
	}

	db, err := sql.Open("snowflake", testURL)
	if err != nil {
		t.Fatalf("failed to initalize the connetion. err: %v", err)
	}
	defer db.Close()
	ctx := context.Background()
	if err = db.PingContext(ctx); err == nil {
		t.Fatal("should cause an error")
	}
	// Expect a *SnowflakeError carrying the "failed to connect" code.
	if driverErr, ok := err.(*SnowflakeError); !ok || driverErr.Number != ErrCodeFailedToConnect {
		t.Fatalf("error didn't match. err: %v", err)
	}
}

// TestOpenWithConfig opens a connection from a Config parsed from the
// package-level dsn.
func TestOpenWithConfig(t *testing.T) {
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatalf("failed to parse dsn. err: %v", err)
	}
	// Named sfDriver so the database/sql/driver package is not shadowed.
	sfDriver := SnowflakeDriver{}
	db, err := sfDriver.OpenWithConfig(context.Background(), *config)
	if err != nil {
		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
	}
	db.Close()
}

// TestOpenWithInvalidConfig expects OpenWithConfig to fail, with an error
// mentioning the path, when tmpDirPath points at a missing directory.
func TestOpenWithInvalidConfig(t *testing.T) {
	config, err := ParseDSN("u:p@h?tmpDirPath=%2Fnon-existing")
	if err != nil {
		t.Fatalf("failed to parse dsn. err: %v", err)
	}
	sfDriver := SnowflakeDriver{}
	_, err = sfDriver.OpenWithConfig(context.Background(), *config)
	if err == nil || !strings.Contains(err.Error(), "/non-existing") {
		t.Fatalf("should fail on missing directory")
	}
}

// CountingTransport is an http.RoundTripper that counts requests and forwards
// them to snowflakeInsecureTransport.
//
// NOTE(review): requests is not synchronized; if the driver issues requests
// from several goroutines this is a data race under -race — consider
// sync/atomic if that is observed.
type CountingTransport struct {
	requests int
}

// RoundTrip increments the request counter and delegates to
// snowflakeInsecureTransport.
func (ct *CountingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	ct.requests++
	return snowflakeInsecureTransport.RoundTrip(r)
}

// TestOpenWithTransport checks that Config.Transporter is installed on the
// connection's HTTP client and receives requests, both in normal and in
// insecure mode.
func TestOpenWithTransport(t *testing.T) {
	config, err := ParseDSN(dsn)
	if err != nil {
		t.Fatalf("failed to parse dsn. err: %v", err)
	}
	countingTransport := CountingTransport{}
	var transport http.RoundTripper = &countingTransport
	config.Transporter = transport
	sfDriver := SnowflakeDriver{}
	db, err := sfDriver.OpenWithConfig(context.Background(), *config)
	if err != nil {
		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
	}
	conn := db.(*snowflakeConn)
	if conn.rest.Client.Transport != transport {
		t.Fatal("transport doesn't match")
	}
	db.Close()
	if countingTransport.requests == 0 {
		t.Fatal("transport did not receive any requests")
	}

	// Test that transport override also works in insecure mode
	countingTransport.requests = 0
	config.InsecureMode = true
	db, err = sfDriver.OpenWithConfig(context.Background(), *config)
	if err != nil {
		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
	}
	conn = db.(*snowflakeConn)
	if conn.rest.Client.Transport != transport {
		t.Fatal("transport doesn't match")
	}
	db.Close()
	if countingTransport.requests == 0 {
		t.Fatal("transport did not receive any requests")
	}
}

// createDSNWithClientSessionKeepAlive overwrites the package-level dsn with
// one that sets client_session_keep_alive=true. Unlike createDSN it does not
// set a timezone parameter.
func createDSNWithClientSessionKeepAlive() {
	dsn = fmt.Sprintf("%s:%s@%s/%s/%s", username, pass, host, dbname, schemaname)

	parameters := url.Values{}
	parameters.Add("client_session_keep_alive", "true")
	if protocol != "" {
		parameters.Add("protocol", protocol)
	}
	if account != "" {
		parameters.Add("account", account)
	}
	if warehouse != "" {
		parameters.Add("warehouse", warehouse)
	}
	if rolename != "" {
		parameters.Add("role", rolename)
	}
	// parameters always holds client_session_keep_alive, so it is never empty.
	dsn += "?" + parameters.Encode()
}

// TestClientSessionKeepAliveParameter checks that the
// client_session_keep_alive DSN parameter is reflected in the session's
// CLIENT_SESSION_KEEP_ALIVE parameter.
func TestClientSessionKeepAliveParameter(t *testing.T) {
	// This test doesn't really validate the CLIENT_SESSION_KEEP_ALIVE functionality but simply checks
	// the session parameter.
	createDSNWithClientSessionKeepAlive()
	// Restore the default package-level dsn so later tests do not inherit
	// the keep-alive DSN (which also lacks the timezone parameter).
	defer createDSN("UTC")
	runDBTest(t, func(dbt *DBTest) {
		rows := dbt.mustQuery("SHOW PARAMETERS LIKE 'CLIENT_SESSION_KEEP_ALIVE'")
		defer rows.Close()
		if !rows.Next() {
			t.Fatal("failed to get client_session_keep_alive.")
		}

		p, err := ScanSnowflakeParameter(rows.rows)
		if err != nil {
			// Fatal: p must not be dereferenced when scanning failed.
			t.Fatalf("failed to run get client_session_keep_alive value. err: %v", err)
		}
		if p.Value != "true" {
			t.Fatalf("failed to get an expected client_session_keep_alive. got: %v", p.Value)
		}

		rows2 := dbt.mustQuery("select count(*) from table(generator(timelimit=>30))")
		defer rows2.Close()
	})
}

// TestTimePrecision checks that a TIME(5) column reports precision 5 through
// ColumnType.DecimalSize.
func TestTimePrecision(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("create or replace table z3 (t1 time(5))")
		rows := dbt.mustQuery("select * from z3")
		defer rows.Close()
		cols, err := rows.ColumnTypes()
		if err != nil {
			// Fatal: cols is indexed below.
			t.Fatal(err)
		}
		if len(cols) == 0 {
			t.Fatal("no columns returned")
		}
		if pres, _, ok := cols[0].DecimalSize(); pres != 5 || !ok {
			t.Fatalf("Wrong value returned. Got %v instead of 5.", pres)
		}
	})
}