github.com/snowflakedb/gosnowflake@v1.9.0/statement_test.go (about) 1 // Copyright (c) 2020-2023 Snowflake Computing Inc. All rights reserved. 2 //lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility. 3 4 package gosnowflake 5 6 import ( 7 "context" 8 "database/sql" 9 "database/sql/driver" 10 "errors" 11 "fmt" 12 "net/http" 13 "net/url" 14 "testing" 15 "time" 16 ) 17 18 func openDB(t *testing.T) *sql.DB { 19 var db *sql.DB 20 var err error 21 22 if db, err = sql.Open("snowflake", dsn); err != nil { 23 t.Fatalf("failed to open db. %v", err) 24 } 25 26 return db 27 } 28 29 func openConn(t *testing.T) *sql.Conn { 30 var db *sql.DB 31 var conn *sql.Conn 32 var err error 33 34 if db, err = sql.Open("snowflake", dsn); err != nil { 35 t.Fatalf("failed to open db. %v, err: %v", dsn, err) 36 } 37 if conn, err = db.Conn(context.Background()); err != nil { 38 t.Fatalf("failed to open connection: %v", err) 39 } 40 return conn 41 } 42 43 func TestExecStmt(t *testing.T) { 44 dqlQuery := "SELECT 1" 45 dmlQuery := "INSERT INTO TestDDLExec VALUES (1)" 46 ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)" 47 multiStmtQuery := "DELETE FROM TestDDLExec;\n" + 48 "SELECT 1;\n" + 49 "SELECT 2;" 50 ctx := context.Background() 51 multiStmtCtx, err := WithMultiStatement(ctx, 3) 52 if err != nil { 53 t.Error(err) 54 } 55 runDBTest(t, func(dbt *DBTest) { 56 dbt.mustExec(ddlQuery) 57 defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec") 58 testcases := []struct { 59 name string 60 query string 61 f func(stmt driver.Stmt) (any, error) 62 }{ 63 { 64 name: "dql Exec", 65 query: dqlQuery, 66 f: func(stmt driver.Stmt) (any, error) { 67 return stmt.Exec(nil) 68 }, 69 }, 70 { 71 name: "dql ExecContext", 72 query: dqlQuery, 73 f: func(stmt driver.Stmt) (any, error) { 74 return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) 75 }, 76 }, 77 { 78 name: "ddl Exec", 79 query: ddlQuery, 80 f: func(stmt driver.Stmt) (any, error) { 81 return stmt.Exec(nil) 82 }, 83 }, 84 { 85 name: "ddl ExecContext", 86 query: ddlQuery, 87 f: func(stmt driver.Stmt) (any, error) { 88 return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) 89 }, 90 }, 91 { 92 name: "dml Exec", 93 query: dmlQuery, 94 f: func(stmt driver.Stmt) (any, error) { 95 return stmt.Exec(nil) 96 }, 97 }, 98 { 99 name: "dml ExecContext", 100 query: dmlQuery, 101 f: func(stmt driver.Stmt) (any, error) { 102 return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) 103 }, 104 }, 105 { 106 name: "multistmt ExecContext", 107 query: multiStmtQuery, 108 f: func(stmt driver.Stmt) (any, error) { 109 return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil) 110 }, 111 }, 112 } 113 for _, tc := range testcases { 114 t.Run(tc.name, func(t *testing.T) { 115 err := dbt.conn.Raw(func(x any) error { 116 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) 117 if err != nil { 118 t.Error(err) 119 } 120 if stmt.(SnowflakeStmt).GetQueryID() != "" { 121 t.Error("queryId should be empty before executing any query") 122 } 123 if _, err := tc.f(stmt); err != nil { 124 t.Errorf("should have not failed to execute the query, err: %s\n", err) 125 } 126 if stmt.(SnowflakeStmt).GetQueryID() == "" { 127 t.Error("should have set the query id") 128 } 129 return nil 130 }) 131 if err != nil { 132 t.Fatal(err) 133 } 134 }) 135 } 136 }) 137 } 138 139 func TestFailedQueryIdInSnowflakeError(t *testing.T) { 140 failingQuery := "SELECTT 1" 141 failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE" 142 143 runDBTest(t, func(dbt *DBTest) { 144 testcases := []struct { 145 name string 146 query string 147 f func(dbt *DBTest) (any, error) 148 }{ 149 { 150 name: "query", 151 f: func(dbt *DBTest) (any, error) { 152 return dbt.query(failingQuery) 153 }, 154 }, 155 { 156 name: "exec", 157 f: func(dbt *DBTest) (any, error) { 158 return dbt.exec(failingExec) 159 }, 160 }, 161 } 162 163 for _, tc := range testcases { 164 t.Run(tc.name, func(t *testing.T) { 165 _, err := tc.f(dbt) 166 if err == nil { 167 t.Error("should have failed") 168 } 169 var snowflakeError *SnowflakeError 170 if !errors.As(err, &snowflakeError) { 171 t.Error("should be a SnowflakeError") 172 } 173 if snowflakeError.QueryID == "" { 174 t.Error("QueryID should be set") 175 } 176 }) 177 } 178 }) 179 } 180 181 func TestSetFailedQueryId(t *testing.T) { 182 ctx := context.Background() 183 failingQuery := "SELECTT 1" 184 failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE" 185 186 runDBTest(t, func(dbt *DBTest) { 187 testcases := []struct { 188 name string 189 query string 190 f func(stmt driver.Stmt) (any, error) 191 }{ 192 { 193 name: "query", 194 query: failingQuery, 195 f: func(stmt driver.Stmt) (any, error) { 196 return stmt.Query(nil) 197 }, 198 }, 199 { 200 name: "exec", 201 query: failingExec, 202 f: func(stmt driver.Stmt) (any, error) { 203 return stmt.Exec(nil) 204 }, 205 }, 206 { 207 name: "queryContext", 208 query: failingQuery, 209 f: func(stmt driver.Stmt) (any, error) { 210 return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) 211 }, 212 }, 213 { 214 name: "execContext", 215 query: failingExec, 216 f: func(stmt driver.Stmt) (any, error) { 217 return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) 218 }, 219 }, 220 } 221 222 for _, tc := range testcases { 223 t.Run(tc.name, func(t *testing.T) { 224 err := dbt.conn.Raw(func(x any) error { 225 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) 226 if err != nil { 227 t.Error(err) 228 } 229 if stmt.(SnowflakeStmt).GetQueryID() != "" { 230 t.Error("queryId should be empty before executing any query") 231 } 232 if _, err := tc.f(stmt); err == nil { 233 t.Error("should have failed to execute the query") 234 } 235 if stmt.(SnowflakeStmt).GetQueryID() == "" { 236 t.Error("should have set the query id") 237 } 238 return nil 239 }) 240 if err != nil { 241 t.Fatal(err) 242 } 243 }) 244 } 245 }) 246 } 247 248 func TestAsyncFailQueryId(t *testing.T) { 249 ctx := WithAsyncMode(context.Background()) 250 runDBTest(t, func(dbt *DBTest) { 251 err := dbt.conn.Raw(func(x any) error { 252 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECTT 1") 253 if err != nil { 254 t.Error(err) 255 } 256 if stmt.(SnowflakeStmt).GetQueryID() != "" { 257 t.Error("queryId should be empty before executing any query") 258 } 259 rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) 260 if err != nil { 261 t.Error("should not fail the initial request") 262 } 263 if rows.(SnowflakeRows).GetStatus() != QueryStatusInProgress { 264 t.Error("should be in progress") 265 } 266 // Wait for the query to complete 267 rows.Next(nil) 268 if rows.(SnowflakeRows).GetStatus() != QueryFailed { 269 t.Error("should have failed") 270 } 271 if rows.(SnowflakeRows).GetQueryID() != stmt.(SnowflakeStmt).GetQueryID() { 272 t.Error("last query id should be the same as rows query id") 273 } 274 return nil 275 }) 276 if err != nil { 277 t.Fatal(err) 278 } 279 }) 280 } 281 282 func TestGetQueryID(t *testing.T) { 283 ctx := context.Background() 284 conn := openConn(t) 285 defer conn.Close() 286 287 if err := conn.Raw(func(x interface{}) error { 288 rows, err := x.(driver.QueryerContext).QueryContext(ctx, "select 1", nil) 289 if err != nil { 290 return err 291 } 292 defer rows.Close() 293 294 if _, err = x.(driver.QueryerContext).QueryContext(ctx, "selectt 1", nil); err == nil { 295 t.Fatal("should have failed to execute query") 296 } 297 if driverErr, ok := err.(*SnowflakeError); ok { 298 if driverErr.Number != 1003 { 299 t.Fatalf("incorrect error code. expected: 1003, got: %v", driverErr.Number) 300 } 301 if driverErr.QueryID == "" { 302 t.Fatal("should have an associated query ID") 303 } 304 } else { 305 t.Fatal("should have been able to cast to Snowflake Error") 306 } 307 return nil 308 }); err != nil { 309 t.Fatalf("failed to prepare statement. err: %v", err) 310 } 311 } 312 313 func TestEmitQueryID(t *testing.T) { 314 queryIDChan := make(chan string, 1) 315 numrows := 100000 316 ctx := WithAsyncMode(context.Background()) 317 ctx = WithQueryIDChan(ctx, queryIDChan) 318 319 goRoutineChan := make(chan string) 320 go func(grCh chan string, qIDch chan string) { 321 queryID := <-queryIDChan 322 grCh <- queryID 323 }(goRoutineChan, queryIDChan) 324 325 cnt := 0 326 var idx int 327 var v string 328 runDBTest(t, func(dbt *DBTest) { 329 rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) 330 defer rows.Close() 331 332 for rows.Next() { 333 if err := rows.Scan(&idx, &v); err != nil { 334 t.Fatal(err) 335 } 336 cnt++ 337 } 338 logger.Infof("NextResultSet: %v", rows.NextResultSet()) 339 }) 340 341 queryID := <-goRoutineChan 342 if queryID == "" { 343 t.Fatal("expected a nonempty query ID") 344 } 345 if cnt != numrows { 346 t.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt) 347 } 348 } 349 350 // End-to-end test to fetch result with queryID 351 func TestE2EFetchResultByID(t *testing.T) { 352 db := openDB(t) 353 defer db.Close() 354 355 if _, err := db.Exec(`create or replace table test_fetch_result(c1 number, 356 c2 string) as select 10, 'z'`); err != nil { 357 t.Fatalf("failed to create table: %v", err) 358 } 359 360 ctx := context.Background() 361 conn, err := db.Conn(ctx) 362 if err != nil { 363 t.Error(err) 364 } 365 if err = conn.Raw(func(x interface{}) error { 366 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "select * from test_fetch_result") 367 if err != nil { 368 return err 369 } 370 371 rows1, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) 372 if err != nil { 373 return err 374 } 375 qid := rows1.(SnowflakeResult).GetQueryID() 376 377 newCtx := context.WithValue(context.Background(), fetchResultByID, qid) 378 rows2, err := db.QueryContext(newCtx, "") 379 if err != nil { 380 t.Fatalf("Fetch Query Result by ID failed: %v", err) 381 } 382 var c1 sql.NullInt64 383 var c2 sql.NullString 384 for rows2.Next() { 385 err = rows2.Scan(&c1, &c2) 386 } 387 if c1.Int64 != 10 || c2.String != "z" { 388 t.Fatalf("Query result is not expected: %v", err) 389 } 390 return nil 391 }); err != nil { 392 t.Fatalf("failed to drop table: %v", err) 393 } 394 395 if _, err := db.Exec("drop table if exists test_fetch_result"); err != nil { 396 t.Fatalf("failed to drop table: %v", err) 397 } 398 } 399 400 func TestWithDescribeOnly(t *testing.T) { 401 runDBTest(t, func(dbt *DBTest) { 402 ctx := WithDescribeOnly(context.Background()) 403 rows := dbt.mustQueryContext(ctx, selectVariousTypes) 404 defer rows.Close() 405 cols, err := rows.Columns() 406 if err != nil { 407 t.Error(err) 408 } 409 types, err := rows.ColumnTypes() 410 if err != nil { 411 t.Error(err) 412 } 413 for i, col := range cols { 414 if types[i].Name() != col { 415 t.Fatalf("column name mismatch. expected: %v, got: %v", col, types[i].Name()) 416 } 417 } 418 if rows.Next() { 419 t.Fatal("there should not be any rows in describe only mode") 420 } 421 }) 422 } 423 424 func TestCallStatement(t *testing.T) { 425 runDBTest(t, func(dbt *DBTest) { 426 in1 := float64(1) 427 in2 := string("[2,3]") 428 expected := "1 \"[2,3]\" [2,3]" 429 var out string 430 431 dbt.exec("ALTER SESSION SET USE_STATEMENT_TYPE_CALL_FOR_STORED_PROC_CALLS = true") 432 433 dbt.mustExec("create or replace procedure " + 434 "TEST_SP_CALL_STMT_ENABLED(in1 float, in2 variant) " + 435 "returns string language javascript as $$ " + 436 "let res = snowflake.execute({sqlText: 'select ? c1, ? c2', binds:[IN1, JSON.stringify(IN2)]}); " + 437 "res.next(); " + 438 "return res.getColumnValueAsString(1) + ' ' + res.getColumnValueAsString(2) + ' ' + IN2; " + 439 "$$;") 440 441 stmt, err := dbt.conn.PrepareContext(context.Background(), "call TEST_SP_CALL_STMT_ENABLED(?, to_variant(?))") 442 if err != nil { 443 dbt.Errorf("failed to prepare query: %v", err) 444 } 445 defer stmt.Close() 446 err = stmt.QueryRow(in1, in2).Scan(&out) 447 if err != nil { 448 dbt.Errorf("failed to scan: %v", err) 449 } 450 451 if expected != out { 452 dbt.Errorf("expected: %s, got: %s", expected, out) 453 } 454 455 dbt.mustExec("drop procedure if exists TEST_SP_CALL_STMT_ENABLED(float, variant)") 456 }) 457 } 458 459 func TestStmtExec(t *testing.T) { 460 ctx := context.Background() 461 conn := openConn(t) 462 defer conn.Close() 463 464 if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil { 465 t.Fatalf("failed to create table: %v", err) 466 } 467 468 if err := conn.Raw(func(x interface{}) error { 469 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (1, 2)") 470 if err != nil { 471 t.Error(err) 472 } 473 _, err = stmt.(*snowflakeStmt).Exec(nil) 474 if err != nil { 475 t.Error(err) 476 } 477 _, err = stmt.(*snowflakeStmt).Query(nil) 478 if err != nil { 479 t.Error(err) 480 } 481 return nil 482 }); err != nil { 483 t.Fatalf("failed to drop table: %v", err) 484 } 485 486 if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil { 487 t.Fatalf("failed to drop table: %v", err) 488 } 489 } 490 491 func TestStmtExec_Error(t *testing.T) { 492 ctx := context.Background() 493 conn := openConn(t) 494 defer conn.Close() 495 496 // Create a test table 497 if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil { 498 t.Fatalf("failed to create table: %v", err) 499 } 500 501 // Attempt to execute an invalid statement 502 if err := conn.Raw(func(x interface{}) error { 503 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (?, ?)") 504 if err != nil { 505 t.Fatalf("failed to prepare statement: %v", err) 506 } 507 508 // Intentionally passing a string instead of an integer to cause an error 509 _, err = stmt.(*snowflakeStmt).Exec([]driver.Value{"invalid_data", 2}) 510 if err == nil { 511 t.Errorf("expected an error, but got none") 512 } 513 514 return nil 515 }); err != nil { 516 t.Fatalf("unexpected error: %v", err) 517 } 518 519 // Drop the test table 520 if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil { 521 t.Fatalf("failed to drop table: %v", err) 522 } 523 } 524 525 func getStatusSuccessButInvalidJSONfunc(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) { 526 return &http.Response{ 527 StatusCode: http.StatusOK, 528 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 529 }, nil 530 } 531 532 func TestUnitCheckQueryStatus(t *testing.T) { 533 sc := getDefaultSnowflakeConn() 534 ctx := context.Background() 535 qid := NewUUID() 536 537 sr := &snowflakeRestful{ 538 FuncGet: getStatusSuccessButInvalidJSONfunc, 539 TokenAccessor: getSimpleTokenAccessor(), 540 } 541 sc.rest = sr 542 _, err := sc.checkQueryStatus(ctx, qid.String()) 543 if err == nil { 544 t.Fatal("invalid json. should have failed") 545 } 546 sc.rest.FuncGet = funcGetQueryRespFail 547 _, err = sc.checkQueryStatus(ctx, qid.String()) 548 if err == nil { 549 t.Fatal("should have failed") 550 } 551 552 sc.rest.FuncGet = funcGetQueryRespError 553 _, err = sc.checkQueryStatus(ctx, qid.String()) 554 if err == nil { 555 t.Fatal("should have failed") 556 } 557 driverErr, ok := err.(*SnowflakeError) 558 if !ok { 559 t.Fatalf("should be snowflake error. err: %v", err) 560 } 561 if driverErr.Number != ErrQueryStatus { 562 t.Fatalf("unexpected error code. expected: %v, got: %v", ErrQueryStatus, driverErr.Number) 563 } 564 } 565 566 func TestStatementQueryIdForQueries(t *testing.T) { 567 ctx := context.Background() 568 conn := openConn(t) 569 defer conn.Close() 570 571 testcases := []struct { 572 name string 573 f func(stmt driver.Stmt) (driver.Rows, error) 574 }{ 575 { 576 "query", 577 func(stmt driver.Stmt) (driver.Rows, error) { 578 return stmt.Query(nil) 579 }, 580 }, 581 { 582 "queryContext", 583 func(stmt driver.Stmt) (driver.Rows, error) { 584 return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) 585 }, 586 }, 587 } 588 589 for _, tc := range testcases { 590 t.Run(tc.name, func(t *testing.T) { 591 err := conn.Raw(func(x any) error { 592 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") 593 if err != nil { 594 t.Fatal(err) 595 } 596 if stmt.(SnowflakeStmt).GetQueryID() != "" { 597 t.Error("queryId should be empty before executing any query") 598 } 599 firstQuery, err := tc.f(stmt) 600 if err != nil { 601 t.Fatal(err) 602 } 603 if stmt.(SnowflakeStmt).GetQueryID() == "" { 604 t.Error("queryId should not be empty after executing query") 605 } 606 if stmt.(SnowflakeStmt).GetQueryID() != firstQuery.(SnowflakeRows).GetQueryID() { 607 t.Error("queryId should be equal among query result and prepared statement") 608 } 609 secondQuery, err := tc.f(stmt) 610 if err != nil { 611 t.Fatal(err) 612 } 613 if stmt.(SnowflakeStmt).GetQueryID() == "" { 614 t.Error("queryId should not be empty after executing query") 615 } 616 if stmt.(SnowflakeStmt).GetQueryID() != secondQuery.(SnowflakeRows).GetQueryID() { 617 t.Error("queryId should be equal among query result and prepared statement") 618 } 619 return nil 620 }) 621 if err != nil { 622 t.Fatal(err) 623 } 624 }) 625 } 626 } 627 628 func TestStatementQuery(t *testing.T) { 629 ctx := context.Background() 630 conn := openConn(t) 631 defer conn.Close() 632 633 testcases := []struct { 634 name string 635 query string 636 f func(stmt driver.Stmt) (driver.Rows, error) 637 wantErr bool 638 }{ 639 { 640 "validQuery", 641 "SELECT 1", 642 func(stmt driver.Stmt) (driver.Rows, error) { 643 return stmt.Query(nil) 644 }, 645 false, 646 }, 647 { 648 "validQueryContext", 649 "SELECT 1", 650 func(stmt driver.Stmt) (driver.Rows, error) { 651 return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) 652 }, 653 false, 654 }, 655 { 656 "invalidQuery", 657 "SELECT * FROM non_existing_table", 658 func(stmt driver.Stmt) (driver.Rows, error) { 659 return stmt.Query(nil) 660 }, 661 true, 662 }, 663 { 664 "invalidQueryContext", 665 "SELECT * FROM non_existing_table", 666 func(stmt driver.Stmt) (driver.Rows, error) { 667 return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil) 668 }, 669 true, 670 }, 671 } 672 673 for _, tc := range testcases { 674 t.Run(tc.name, func(t *testing.T) { 675 err := conn.Raw(func(x any) error { 676 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) 677 if err != nil { 678 if tc.wantErr { 679 return nil // expected error 680 } 681 t.Fatal(err) 682 } 683 684 _, err = tc.f(stmt) 685 if (err != nil) != tc.wantErr { 686 t.Fatalf("error = %v, wantErr %v", err, tc.wantErr) 687 } 688 689 return nil 690 }) 691 if err != nil { 692 t.Fatal(err) 693 } 694 }) 695 } 696 } 697 698 func TestStatementQueryIdForExecs(t *testing.T) { 699 ctx := context.Background() 700 runDBTest(t, func(dbt *DBTest) { 701 dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)") 702 defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs") 703 704 testcases := []struct { 705 name string 706 f func(stmt driver.Stmt) (driver.Result, error) 707 }{ 708 { 709 "exec", 710 func(stmt driver.Stmt) (driver.Result, error) { 711 return stmt.Exec(nil) 712 }, 713 }, 714 { 715 "execContext", 716 func(stmt driver.Stmt) (driver.Result, error) { 717 return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) 718 }, 719 }, 720 } 721 722 for _, tc := range testcases { 723 t.Run(tc.name, func(t *testing.T) { 724 err := dbt.conn.Raw(func(x any) error { 725 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)") 726 if err != nil { 727 t.Fatal(err) 728 } 729 if stmt.(SnowflakeStmt).GetQueryID() != "" { 730 t.Error("queryId should be empty before executing any query") 731 } 732 firstExec, err := tc.f(stmt) 733 if err != nil { 734 t.Fatal(err) 735 } 736 if stmt.(SnowflakeStmt).GetQueryID() == "" { 737 t.Error("queryId should not be empty after executing query") 738 } 739 if stmt.(SnowflakeStmt).GetQueryID() != firstExec.(SnowflakeResult).GetQueryID() { 740 t.Error("queryId should be equal among query result and prepared statement") 741 } 742 secondExec, err := tc.f(stmt) 743 if err != nil { 744 t.Fatal(err) 745 } 746 if stmt.(SnowflakeStmt).GetQueryID() == "" { 747 t.Error("queryId should not be empty after executing query") 748 } 749 if stmt.(SnowflakeStmt).GetQueryID() != secondExec.(SnowflakeResult).GetQueryID() { 750 t.Error("queryId should be equal among query result and prepared statement") 751 } 752 return nil 753 }) 754 if err != nil { 755 t.Fatal(err) 756 } 757 }) 758 } 759 }) 760 } 761 762 func TestStatementQueryExecs(t *testing.T) { 763 ctx := context.Background() 764 runDBTest(t, func(dbt *DBTest) { 765 dbt.mustExec("CREATE TABLE TestStatementQueryExecs (v INTEGER)") 766 defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementForExecs") 767 768 testcases := []struct { 769 name string 770 query string 771 f func(stmt driver.Stmt) (driver.Result, error) 772 wantErr bool 773 }{ 774 { 775 "validExec", 776 "INSERT INTO TestStatementQueryExecs VALUES (1)", 777 func(stmt driver.Stmt) (driver.Result, error) { 778 return stmt.Exec(nil) 779 }, 780 false, 781 }, 782 { 783 "validExecContext", 784 "INSERT INTO TestStatementQueryExecs VALUES (1)", 785 func(stmt driver.Stmt) (driver.Result, error) { 786 return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) 787 }, 788 false, 789 }, 790 { 791 "invalidExec", 792 "INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')", 793 func(stmt driver.Stmt) (driver.Result, error) { 794 return stmt.Exec(nil) 795 }, 796 true, 797 }, 798 { 799 "invalidExecContext", 800 "INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')", 801 func(stmt driver.Stmt) (driver.Result, error) { 802 return stmt.(driver.StmtExecContext).ExecContext(ctx, nil) 803 }, 804 true, 805 }, 806 } 807 808 for _, tc := range testcases { 809 t.Run(tc.name, func(t *testing.T) { 810 err := dbt.conn.Raw(func(x any) error { 811 stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query) 812 if err != nil { 813 if tc.wantErr { 814 return nil // expected error 815 } 816 t.Fatal(err) 817 } 818 819 _, err = tc.f(stmt) 820 if (err != nil) != tc.wantErr { 821 t.Fatalf("error = %v, wantErr %v", err, tc.wantErr) 822 } 823 824 return nil 825 }) 826 if err != nil { 827 t.Fatal(err) 828 } 829 }) 830 } 831 }) 832 } 833 834 func TestWithQueryTag(t *testing.T) { 835 runDBTest(t, func(dbt *DBTest) { 836 testQueryTag := "TEST QUERY TAG" 837 ctx := WithQueryTag(context.Background(), testQueryTag) 838 839 // This query itself will be part of the history and will have the query tag 840 rows := dbt.mustQueryContext( 841 ctx, 842 "SELECT QUERY_TAG FROM table(information_schema.query_history_by_session())") 843 defer rows.Close() 844 845 assertTrueF(t, rows.Next()) 846 var tag sql.NullString 847 err := rows.Scan(&tag) 848 assertNilF(t, err) 849 assertTrueF(t, tag.Valid, "no QUERY_TAG set") 850 assertEqualF(t, tag.String, testQueryTag) 851 }) 852 }