github.com/snowflakedb/gosnowflake@v1.9.0/statement_test.go (about)

     1  // Copyright (c) 2020-2023 Snowflake Computing Inc. All rights reserved.
     2  //lint:file-ignore SA1019 Ignore deprecated methods. We should leave them as-is to keep backward compatibility.
     3  
     4  package gosnowflake
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"net/http"
    13  	"net/url"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  func openDB(t *testing.T) *sql.DB {
    19  	var db *sql.DB
    20  	var err error
    21  
    22  	if db, err = sql.Open("snowflake", dsn); err != nil {
    23  		t.Fatalf("failed to open db. %v", err)
    24  	}
    25  
    26  	return db
    27  }
    28  
    29  func openConn(t *testing.T) *sql.Conn {
    30  	var db *sql.DB
    31  	var conn *sql.Conn
    32  	var err error
    33  
    34  	if db, err = sql.Open("snowflake", dsn); err != nil {
    35  		t.Fatalf("failed to open db. %v, err: %v", dsn, err)
    36  	}
    37  	if conn, err = db.Conn(context.Background()); err != nil {
    38  		t.Fatalf("failed to open connection: %v", err)
    39  	}
    40  	return conn
    41  }
    42  
    43  func TestExecStmt(t *testing.T) {
    44  	dqlQuery := "SELECT 1"
    45  	dmlQuery := "INSERT INTO TestDDLExec VALUES (1)"
    46  	ddlQuery := "CREATE OR REPLACE TABLE TestDDLExec (num NUMBER)"
    47  	multiStmtQuery := "DELETE FROM TestDDLExec;\n" +
    48  		"SELECT 1;\n" +
    49  		"SELECT 2;"
    50  	ctx := context.Background()
    51  	multiStmtCtx, err := WithMultiStatement(ctx, 3)
    52  	if err != nil {
    53  		t.Error(err)
    54  	}
    55  	runDBTest(t, func(dbt *DBTest) {
    56  		dbt.mustExec(ddlQuery)
    57  		defer dbt.mustExec("DROP TABLE IF EXISTS TestDDLExec")
    58  		testcases := []struct {
    59  			name  string
    60  			query string
    61  			f     func(stmt driver.Stmt) (any, error)
    62  		}{
    63  			{
    64  				name:  "dql Exec",
    65  				query: dqlQuery,
    66  				f: func(stmt driver.Stmt) (any, error) {
    67  					return stmt.Exec(nil)
    68  				},
    69  			},
    70  			{
    71  				name:  "dql ExecContext",
    72  				query: dqlQuery,
    73  				f: func(stmt driver.Stmt) (any, error) {
    74  					return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
    75  				},
    76  			},
    77  			{
    78  				name:  "ddl Exec",
    79  				query: ddlQuery,
    80  				f: func(stmt driver.Stmt) (any, error) {
    81  					return stmt.Exec(nil)
    82  				},
    83  			},
    84  			{
    85  				name:  "ddl ExecContext",
    86  				query: ddlQuery,
    87  				f: func(stmt driver.Stmt) (any, error) {
    88  					return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
    89  				},
    90  			},
    91  			{
    92  				name:  "dml Exec",
    93  				query: dmlQuery,
    94  				f: func(stmt driver.Stmt) (any, error) {
    95  					return stmt.Exec(nil)
    96  				},
    97  			},
    98  			{
    99  				name:  "dml ExecContext",
   100  				query: dmlQuery,
   101  				f: func(stmt driver.Stmt) (any, error) {
   102  					return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
   103  				},
   104  			},
   105  			{
   106  				name:  "multistmt ExecContext",
   107  				query: multiStmtQuery,
   108  				f: func(stmt driver.Stmt) (any, error) {
   109  					return stmt.(driver.StmtExecContext).ExecContext(multiStmtCtx, nil)
   110  				},
   111  			},
   112  		}
   113  		for _, tc := range testcases {
   114  			t.Run(tc.name, func(t *testing.T) {
   115  				err := dbt.conn.Raw(func(x any) error {
   116  					stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
   117  					if err != nil {
   118  						t.Error(err)
   119  					}
   120  					if stmt.(SnowflakeStmt).GetQueryID() != "" {
   121  						t.Error("queryId should be empty before executing any query")
   122  					}
   123  					if _, err := tc.f(stmt); err != nil {
   124  						t.Errorf("should have not failed to execute the query, err: %s\n", err)
   125  					}
   126  					if stmt.(SnowflakeStmt).GetQueryID() == "" {
   127  						t.Error("should have set the query id")
   128  					}
   129  					return nil
   130  				})
   131  				if err != nil {
   132  					t.Fatal(err)
   133  				}
   134  			})
   135  		}
   136  	})
   137  }
   138  
   139  func TestFailedQueryIdInSnowflakeError(t *testing.T) {
   140  	failingQuery := "SELECTT 1"
   141  	failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"
   142  
   143  	runDBTest(t, func(dbt *DBTest) {
   144  		testcases := []struct {
   145  			name  string
   146  			query string
   147  			f     func(dbt *DBTest) (any, error)
   148  		}{
   149  			{
   150  				name: "query",
   151  				f: func(dbt *DBTest) (any, error) {
   152  					return dbt.query(failingQuery)
   153  				},
   154  			},
   155  			{
   156  				name: "exec",
   157  				f: func(dbt *DBTest) (any, error) {
   158  					return dbt.exec(failingExec)
   159  				},
   160  			},
   161  		}
   162  
   163  		for _, tc := range testcases {
   164  			t.Run(tc.name, func(t *testing.T) {
   165  				_, err := tc.f(dbt)
   166  				if err == nil {
   167  					t.Error("should have failed")
   168  				}
   169  				var snowflakeError *SnowflakeError
   170  				if !errors.As(err, &snowflakeError) {
   171  					t.Error("should be a SnowflakeError")
   172  				}
   173  				if snowflakeError.QueryID == "" {
   174  					t.Error("QueryID should be set")
   175  				}
   176  			})
   177  		}
   178  	})
   179  }
   180  
   181  func TestSetFailedQueryId(t *testing.T) {
   182  	ctx := context.Background()
   183  	failingQuery := "SELECTT 1"
   184  	failingExec := "INSERT 1 INTO NON_EXISTENT_TABLE"
   185  
   186  	runDBTest(t, func(dbt *DBTest) {
   187  		testcases := []struct {
   188  			name  string
   189  			query string
   190  			f     func(stmt driver.Stmt) (any, error)
   191  		}{
   192  			{
   193  				name:  "query",
   194  				query: failingQuery,
   195  				f: func(stmt driver.Stmt) (any, error) {
   196  					return stmt.Query(nil)
   197  				},
   198  			},
   199  			{
   200  				name:  "exec",
   201  				query: failingExec,
   202  				f: func(stmt driver.Stmt) (any, error) {
   203  					return stmt.Exec(nil)
   204  				},
   205  			},
   206  			{
   207  				name:  "queryContext",
   208  				query: failingQuery,
   209  				f: func(stmt driver.Stmt) (any, error) {
   210  					return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
   211  				},
   212  			},
   213  			{
   214  				name:  "execContext",
   215  				query: failingExec,
   216  				f: func(stmt driver.Stmt) (any, error) {
   217  					return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
   218  				},
   219  			},
   220  		}
   221  
   222  		for _, tc := range testcases {
   223  			t.Run(tc.name, func(t *testing.T) {
   224  				err := dbt.conn.Raw(func(x any) error {
   225  					stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
   226  					if err != nil {
   227  						t.Error(err)
   228  					}
   229  					if stmt.(SnowflakeStmt).GetQueryID() != "" {
   230  						t.Error("queryId should be empty before executing any query")
   231  					}
   232  					if _, err := tc.f(stmt); err == nil {
   233  						t.Error("should have failed to execute the query")
   234  					}
   235  					if stmt.(SnowflakeStmt).GetQueryID() == "" {
   236  						t.Error("should have set the query id")
   237  					}
   238  					return nil
   239  				})
   240  				if err != nil {
   241  					t.Fatal(err)
   242  				}
   243  			})
   244  		}
   245  	})
   246  }
   247  
   248  func TestAsyncFailQueryId(t *testing.T) {
   249  	ctx := WithAsyncMode(context.Background())
   250  	runDBTest(t, func(dbt *DBTest) {
   251  		err := dbt.conn.Raw(func(x any) error {
   252  			stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECTT 1")
   253  			if err != nil {
   254  				t.Error(err)
   255  			}
   256  			if stmt.(SnowflakeStmt).GetQueryID() != "" {
   257  				t.Error("queryId should be empty before executing any query")
   258  			}
   259  			rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
   260  			if err != nil {
   261  				t.Error("should not fail the initial request")
   262  			}
   263  			if rows.(SnowflakeRows).GetStatus() != QueryStatusInProgress {
   264  				t.Error("should be in progress")
   265  			}
   266  			// Wait for the query to complete
   267  			rows.Next(nil)
   268  			if rows.(SnowflakeRows).GetStatus() != QueryFailed {
   269  				t.Error("should have failed")
   270  			}
   271  			if rows.(SnowflakeRows).GetQueryID() != stmt.(SnowflakeStmt).GetQueryID() {
   272  				t.Error("last query id should be the same as rows query id")
   273  			}
   274  			return nil
   275  		})
   276  		if err != nil {
   277  			t.Fatal(err)
   278  		}
   279  	})
   280  }
   281  
   282  func TestGetQueryID(t *testing.T) {
   283  	ctx := context.Background()
   284  	conn := openConn(t)
   285  	defer conn.Close()
   286  
   287  	if err := conn.Raw(func(x interface{}) error {
   288  		rows, err := x.(driver.QueryerContext).QueryContext(ctx, "select 1", nil)
   289  		if err != nil {
   290  			return err
   291  		}
   292  		defer rows.Close()
   293  
   294  		if _, err = x.(driver.QueryerContext).QueryContext(ctx, "selectt 1", nil); err == nil {
   295  			t.Fatal("should have failed to execute query")
   296  		}
   297  		if driverErr, ok := err.(*SnowflakeError); ok {
   298  			if driverErr.Number != 1003 {
   299  				t.Fatalf("incorrect error code. expected: 1003, got: %v", driverErr.Number)
   300  			}
   301  			if driverErr.QueryID == "" {
   302  				t.Fatal("should have an associated query ID")
   303  			}
   304  		} else {
   305  			t.Fatal("should have been able to cast to Snowflake Error")
   306  		}
   307  		return nil
   308  	}); err != nil {
   309  		t.Fatalf("failed to prepare statement. err: %v", err)
   310  	}
   311  }
   312  
   313  func TestEmitQueryID(t *testing.T) {
   314  	queryIDChan := make(chan string, 1)
   315  	numrows := 100000
   316  	ctx := WithAsyncMode(context.Background())
   317  	ctx = WithQueryIDChan(ctx, queryIDChan)
   318  
   319  	goRoutineChan := make(chan string)
   320  	go func(grCh chan string, qIDch chan string) {
   321  		queryID := <-queryIDChan
   322  		grCh <- queryID
   323  	}(goRoutineChan, queryIDChan)
   324  
   325  	cnt := 0
   326  	var idx int
   327  	var v string
   328  	runDBTest(t, func(dbt *DBTest) {
   329  		rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows))
   330  		defer rows.Close()
   331  
   332  		for rows.Next() {
   333  			if err := rows.Scan(&idx, &v); err != nil {
   334  				t.Fatal(err)
   335  			}
   336  			cnt++
   337  		}
   338  		logger.Infof("NextResultSet: %v", rows.NextResultSet())
   339  	})
   340  
   341  	queryID := <-goRoutineChan
   342  	if queryID == "" {
   343  		t.Fatal("expected a nonempty query ID")
   344  	}
   345  	if cnt != numrows {
   346  		t.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
   347  	}
   348  }
   349  
   350  // End-to-end test to fetch result with queryID
   351  func TestE2EFetchResultByID(t *testing.T) {
   352  	db := openDB(t)
   353  	defer db.Close()
   354  
   355  	if _, err := db.Exec(`create or replace table test_fetch_result(c1 number,
   356  		c2 string) as select 10, 'z'`); err != nil {
   357  		t.Fatalf("failed to create table: %v", err)
   358  	}
   359  
   360  	ctx := context.Background()
   361  	conn, err := db.Conn(ctx)
   362  	if err != nil {
   363  		t.Error(err)
   364  	}
   365  	if err = conn.Raw(func(x interface{}) error {
   366  		stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "select * from test_fetch_result")
   367  		if err != nil {
   368  			return err
   369  		}
   370  
   371  		rows1, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
   372  		if err != nil {
   373  			return err
   374  		}
   375  		qid := rows1.(SnowflakeResult).GetQueryID()
   376  
   377  		newCtx := context.WithValue(context.Background(), fetchResultByID, qid)
   378  		rows2, err := db.QueryContext(newCtx, "")
   379  		if err != nil {
   380  			t.Fatalf("Fetch Query Result by ID failed: %v", err)
   381  		}
   382  		var c1 sql.NullInt64
   383  		var c2 sql.NullString
   384  		for rows2.Next() {
   385  			err = rows2.Scan(&c1, &c2)
   386  		}
   387  		if c1.Int64 != 10 || c2.String != "z" {
   388  			t.Fatalf("Query result is not expected: %v", err)
   389  		}
   390  		return nil
   391  	}); err != nil {
   392  		t.Fatalf("failed to drop table: %v", err)
   393  	}
   394  
   395  	if _, err := db.Exec("drop table if exists test_fetch_result"); err != nil {
   396  		t.Fatalf("failed to drop table: %v", err)
   397  	}
   398  }
   399  
   400  func TestWithDescribeOnly(t *testing.T) {
   401  	runDBTest(t, func(dbt *DBTest) {
   402  		ctx := WithDescribeOnly(context.Background())
   403  		rows := dbt.mustQueryContext(ctx, selectVariousTypes)
   404  		defer rows.Close()
   405  		cols, err := rows.Columns()
   406  		if err != nil {
   407  			t.Error(err)
   408  		}
   409  		types, err := rows.ColumnTypes()
   410  		if err != nil {
   411  			t.Error(err)
   412  		}
   413  		for i, col := range cols {
   414  			if types[i].Name() != col {
   415  				t.Fatalf("column name mismatch. expected: %v, got: %v", col, types[i].Name())
   416  			}
   417  		}
   418  		if rows.Next() {
   419  			t.Fatal("there should not be any rows in describe only mode")
   420  		}
   421  	})
   422  }
   423  
   424  func TestCallStatement(t *testing.T) {
   425  	runDBTest(t, func(dbt *DBTest) {
   426  		in1 := float64(1)
   427  		in2 := string("[2,3]")
   428  		expected := "1 \"[2,3]\" [2,3]"
   429  		var out string
   430  
   431  		dbt.exec("ALTER SESSION SET USE_STATEMENT_TYPE_CALL_FOR_STORED_PROC_CALLS = true")
   432  
   433  		dbt.mustExec("create or replace procedure " +
   434  			"TEST_SP_CALL_STMT_ENABLED(in1 float, in2 variant) " +
   435  			"returns string language javascript as $$ " +
   436  			"let res = snowflake.execute({sqlText: 'select ? c1, ? c2', binds:[IN1, JSON.stringify(IN2)]}); " +
   437  			"res.next(); " +
   438  			"return res.getColumnValueAsString(1) + ' ' + res.getColumnValueAsString(2) + ' ' + IN2; " +
   439  			"$$;")
   440  
   441  		stmt, err := dbt.conn.PrepareContext(context.Background(), "call TEST_SP_CALL_STMT_ENABLED(?, to_variant(?))")
   442  		if err != nil {
   443  			dbt.Errorf("failed to prepare query: %v", err)
   444  		}
   445  		defer stmt.Close()
   446  		err = stmt.QueryRow(in1, in2).Scan(&out)
   447  		if err != nil {
   448  			dbt.Errorf("failed to scan: %v", err)
   449  		}
   450  
   451  		if expected != out {
   452  			dbt.Errorf("expected: %s, got: %s", expected, out)
   453  		}
   454  
   455  		dbt.mustExec("drop procedure if exists TEST_SP_CALL_STMT_ENABLED(float, variant)")
   456  	})
   457  }
   458  
   459  func TestStmtExec(t *testing.T) {
   460  	ctx := context.Background()
   461  	conn := openConn(t)
   462  	defer conn.Close()
   463  
   464  	if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil {
   465  		t.Fatalf("failed to create table: %v", err)
   466  	}
   467  
   468  	if err := conn.Raw(func(x interface{}) error {
   469  		stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (1, 2)")
   470  		if err != nil {
   471  			t.Error(err)
   472  		}
   473  		_, err = stmt.(*snowflakeStmt).Exec(nil)
   474  		if err != nil {
   475  			t.Error(err)
   476  		}
   477  		_, err = stmt.(*snowflakeStmt).Query(nil)
   478  		if err != nil {
   479  			t.Error(err)
   480  		}
   481  		return nil
   482  	}); err != nil {
   483  		t.Fatalf("failed to drop table: %v", err)
   484  	}
   485  
   486  	if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil {
   487  		t.Fatalf("failed to drop table: %v", err)
   488  	}
   489  }
   490  
   491  func TestStmtExec_Error(t *testing.T) {
   492  	ctx := context.Background()
   493  	conn := openConn(t)
   494  	defer conn.Close()
   495  
   496  	// Create a test table
   497  	if _, err := conn.ExecContext(ctx, `create or replace table test_table(col1 int, col2 int)`); err != nil {
   498  		t.Fatalf("failed to create table: %v", err)
   499  	}
   500  
   501  	// Attempt to execute an invalid statement
   502  	if err := conn.Raw(func(x interface{}) error {
   503  		stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "insert into test_table values (?, ?)")
   504  		if err != nil {
   505  			t.Fatalf("failed to prepare statement: %v", err)
   506  		}
   507  
   508  		// Intentionally passing a string instead of an integer to cause an error
   509  		_, err = stmt.(*snowflakeStmt).Exec([]driver.Value{"invalid_data", 2})
   510  		if err == nil {
   511  			t.Errorf("expected an error, but got none")
   512  		}
   513  
   514  		return nil
   515  	}); err != nil {
   516  		t.Fatalf("unexpected error: %v", err)
   517  	}
   518  
   519  	// Drop the test table
   520  	if _, err := conn.ExecContext(ctx, "drop table if exists test_table"); err != nil {
   521  		t.Fatalf("failed to drop table: %v", err)
   522  	}
   523  }
   524  
   525  func getStatusSuccessButInvalidJSONfunc(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
   526  	return &http.Response{
   527  		StatusCode: http.StatusOK,
   528  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
   529  	}, nil
   530  }
   531  
   532  func TestUnitCheckQueryStatus(t *testing.T) {
   533  	sc := getDefaultSnowflakeConn()
   534  	ctx := context.Background()
   535  	qid := NewUUID()
   536  
   537  	sr := &snowflakeRestful{
   538  		FuncGet:       getStatusSuccessButInvalidJSONfunc,
   539  		TokenAccessor: getSimpleTokenAccessor(),
   540  	}
   541  	sc.rest = sr
   542  	_, err := sc.checkQueryStatus(ctx, qid.String())
   543  	if err == nil {
   544  		t.Fatal("invalid json. should have failed")
   545  	}
   546  	sc.rest.FuncGet = funcGetQueryRespFail
   547  	_, err = sc.checkQueryStatus(ctx, qid.String())
   548  	if err == nil {
   549  		t.Fatal("should have failed")
   550  	}
   551  
   552  	sc.rest.FuncGet = funcGetQueryRespError
   553  	_, err = sc.checkQueryStatus(ctx, qid.String())
   554  	if err == nil {
   555  		t.Fatal("should have failed")
   556  	}
   557  	driverErr, ok := err.(*SnowflakeError)
   558  	if !ok {
   559  		t.Fatalf("should be snowflake error. err: %v", err)
   560  	}
   561  	if driverErr.Number != ErrQueryStatus {
   562  		t.Fatalf("unexpected error code. expected: %v, got: %v", ErrQueryStatus, driverErr.Number)
   563  	}
   564  }
   565  
   566  func TestStatementQueryIdForQueries(t *testing.T) {
   567  	ctx := context.Background()
   568  	conn := openConn(t)
   569  	defer conn.Close()
   570  
   571  	testcases := []struct {
   572  		name string
   573  		f    func(stmt driver.Stmt) (driver.Rows, error)
   574  	}{
   575  		{
   576  			"query",
   577  			func(stmt driver.Stmt) (driver.Rows, error) {
   578  				return stmt.Query(nil)
   579  			},
   580  		},
   581  		{
   582  			"queryContext",
   583  			func(stmt driver.Stmt) (driver.Rows, error) {
   584  				return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
   585  			},
   586  		},
   587  	}
   588  
   589  	for _, tc := range testcases {
   590  		t.Run(tc.name, func(t *testing.T) {
   591  			err := conn.Raw(func(x any) error {
   592  				stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1")
   593  				if err != nil {
   594  					t.Fatal(err)
   595  				}
   596  				if stmt.(SnowflakeStmt).GetQueryID() != "" {
   597  					t.Error("queryId should be empty before executing any query")
   598  				}
   599  				firstQuery, err := tc.f(stmt)
   600  				if err != nil {
   601  					t.Fatal(err)
   602  				}
   603  				if stmt.(SnowflakeStmt).GetQueryID() == "" {
   604  					t.Error("queryId should not be empty after executing query")
   605  				}
   606  				if stmt.(SnowflakeStmt).GetQueryID() != firstQuery.(SnowflakeRows).GetQueryID() {
   607  					t.Error("queryId should be equal among query result and prepared statement")
   608  				}
   609  				secondQuery, err := tc.f(stmt)
   610  				if err != nil {
   611  					t.Fatal(err)
   612  				}
   613  				if stmt.(SnowflakeStmt).GetQueryID() == "" {
   614  					t.Error("queryId should not be empty after executing query")
   615  				}
   616  				if stmt.(SnowflakeStmt).GetQueryID() != secondQuery.(SnowflakeRows).GetQueryID() {
   617  					t.Error("queryId should be equal among query result and prepared statement")
   618  				}
   619  				return nil
   620  			})
   621  			if err != nil {
   622  				t.Fatal(err)
   623  			}
   624  		})
   625  	}
   626  }
   627  
   628  func TestStatementQuery(t *testing.T) {
   629  	ctx := context.Background()
   630  	conn := openConn(t)
   631  	defer conn.Close()
   632  
   633  	testcases := []struct {
   634  		name    string
   635  		query   string
   636  		f       func(stmt driver.Stmt) (driver.Rows, error)
   637  		wantErr bool
   638  	}{
   639  		{
   640  			"validQuery",
   641  			"SELECT 1",
   642  			func(stmt driver.Stmt) (driver.Rows, error) {
   643  				return stmt.Query(nil)
   644  			},
   645  			false,
   646  		},
   647  		{
   648  			"validQueryContext",
   649  			"SELECT 1",
   650  			func(stmt driver.Stmt) (driver.Rows, error) {
   651  				return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
   652  			},
   653  			false,
   654  		},
   655  		{
   656  			"invalidQuery",
   657  			"SELECT * FROM non_existing_table",
   658  			func(stmt driver.Stmt) (driver.Rows, error) {
   659  				return stmt.Query(nil)
   660  			},
   661  			true,
   662  		},
   663  		{
   664  			"invalidQueryContext",
   665  			"SELECT * FROM non_existing_table",
   666  			func(stmt driver.Stmt) (driver.Rows, error) {
   667  				return stmt.(driver.StmtQueryContext).QueryContext(ctx, nil)
   668  			},
   669  			true,
   670  		},
   671  	}
   672  
   673  	for _, tc := range testcases {
   674  		t.Run(tc.name, func(t *testing.T) {
   675  			err := conn.Raw(func(x any) error {
   676  				stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
   677  				if err != nil {
   678  					if tc.wantErr {
   679  						return nil // expected error
   680  					}
   681  					t.Fatal(err)
   682  				}
   683  
   684  				_, err = tc.f(stmt)
   685  				if (err != nil) != tc.wantErr {
   686  					t.Fatalf("error = %v, wantErr %v", err, tc.wantErr)
   687  				}
   688  
   689  				return nil
   690  			})
   691  			if err != nil {
   692  				t.Fatal(err)
   693  			}
   694  		})
   695  	}
   696  }
   697  
   698  func TestStatementQueryIdForExecs(t *testing.T) {
   699  	ctx := context.Background()
   700  	runDBTest(t, func(dbt *DBTest) {
   701  		dbt.mustExec("CREATE TABLE TestStatementQueryIdForExecs (v INTEGER)")
   702  		defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementQueryIdForExecs")
   703  
   704  		testcases := []struct {
   705  			name string
   706  			f    func(stmt driver.Stmt) (driver.Result, error)
   707  		}{
   708  			{
   709  				"exec",
   710  				func(stmt driver.Stmt) (driver.Result, error) {
   711  					return stmt.Exec(nil)
   712  				},
   713  			},
   714  			{
   715  				"execContext",
   716  				func(stmt driver.Stmt) (driver.Result, error) {
   717  					return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
   718  				},
   719  			},
   720  		}
   721  
   722  		for _, tc := range testcases {
   723  			t.Run(tc.name, func(t *testing.T) {
   724  				err := dbt.conn.Raw(func(x any) error {
   725  					stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, "INSERT INTO TestStatementQueryIdForExecs VALUES (1)")
   726  					if err != nil {
   727  						t.Fatal(err)
   728  					}
   729  					if stmt.(SnowflakeStmt).GetQueryID() != "" {
   730  						t.Error("queryId should be empty before executing any query")
   731  					}
   732  					firstExec, err := tc.f(stmt)
   733  					if err != nil {
   734  						t.Fatal(err)
   735  					}
   736  					if stmt.(SnowflakeStmt).GetQueryID() == "" {
   737  						t.Error("queryId should not be empty after executing query")
   738  					}
   739  					if stmt.(SnowflakeStmt).GetQueryID() != firstExec.(SnowflakeResult).GetQueryID() {
   740  						t.Error("queryId should be equal among query result and prepared statement")
   741  					}
   742  					secondExec, err := tc.f(stmt)
   743  					if err != nil {
   744  						t.Fatal(err)
   745  					}
   746  					if stmt.(SnowflakeStmt).GetQueryID() == "" {
   747  						t.Error("queryId should not be empty after executing query")
   748  					}
   749  					if stmt.(SnowflakeStmt).GetQueryID() != secondExec.(SnowflakeResult).GetQueryID() {
   750  						t.Error("queryId should be equal among query result and prepared statement")
   751  					}
   752  					return nil
   753  				})
   754  				if err != nil {
   755  					t.Fatal(err)
   756  				}
   757  			})
   758  		}
   759  	})
   760  }
   761  
   762  func TestStatementQueryExecs(t *testing.T) {
   763  	ctx := context.Background()
   764  	runDBTest(t, func(dbt *DBTest) {
   765  		dbt.mustExec("CREATE TABLE TestStatementQueryExecs (v INTEGER)")
   766  		defer dbt.mustExec("DROP TABLE IF EXISTS TestStatementForExecs")
   767  
   768  		testcases := []struct {
   769  			name    string
   770  			query   string
   771  			f       func(stmt driver.Stmt) (driver.Result, error)
   772  			wantErr bool
   773  		}{
   774  			{
   775  				"validExec",
   776  				"INSERT INTO TestStatementQueryExecs VALUES (1)",
   777  				func(stmt driver.Stmt) (driver.Result, error) {
   778  					return stmt.Exec(nil)
   779  				},
   780  				false,
   781  			},
   782  			{
   783  				"validExecContext",
   784  				"INSERT INTO TestStatementQueryExecs VALUES (1)",
   785  				func(stmt driver.Stmt) (driver.Result, error) {
   786  					return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
   787  				},
   788  				false,
   789  			},
   790  			{
   791  				"invalidExec",
   792  				"INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')",
   793  				func(stmt driver.Stmt) (driver.Result, error) {
   794  					return stmt.Exec(nil)
   795  				},
   796  				true,
   797  			},
   798  			{
   799  				"invalidExecContext",
   800  				"INSERT INTO TestStatementQueryExecs VALUES ('invalid_data')",
   801  				func(stmt driver.Stmt) (driver.Result, error) {
   802  					return stmt.(driver.StmtExecContext).ExecContext(ctx, nil)
   803  				},
   804  				true,
   805  			},
   806  		}
   807  
   808  		for _, tc := range testcases {
   809  			t.Run(tc.name, func(t *testing.T) {
   810  				err := dbt.conn.Raw(func(x any) error {
   811  					stmt, err := x.(driver.ConnPrepareContext).PrepareContext(ctx, tc.query)
   812  					if err != nil {
   813  						if tc.wantErr {
   814  							return nil // expected error
   815  						}
   816  						t.Fatal(err)
   817  					}
   818  
   819  					_, err = tc.f(stmt)
   820  					if (err != nil) != tc.wantErr {
   821  						t.Fatalf("error = %v, wantErr %v", err, tc.wantErr)
   822  					}
   823  
   824  					return nil
   825  				})
   826  				if err != nil {
   827  					t.Fatal(err)
   828  				}
   829  			})
   830  		}
   831  	})
   832  }
   833  
   834  func TestWithQueryTag(t *testing.T) {
   835  	runDBTest(t, func(dbt *DBTest) {
   836  		testQueryTag := "TEST QUERY TAG"
   837  		ctx := WithQueryTag(context.Background(), testQueryTag)
   838  
   839  		// This query itself will be part of the history and will have the query tag
   840  		rows := dbt.mustQueryContext(
   841  			ctx,
   842  			"SELECT QUERY_TAG FROM table(information_schema.query_history_by_session())")
   843  		defer rows.Close()
   844  
   845  		assertTrueF(t, rows.Next())
   846  		var tag sql.NullString
   847  		err := rows.Scan(&tag)
   848  		assertNilF(t, err)
   849  		assertTrueF(t, tag.Valid, "no QUERY_TAG set")
   850  		assertEqualF(t, tag.String, testQueryTag)
   851  	})
   852  }