github.com/snowflakedb/gosnowflake@v1.9.0/driver_test.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"crypto/rsa"
     8  	"database/sql"
     9  	"database/sql/driver"
    10  	"flag"
    11  	"fmt"
    12  	"net/http"
    13  	"net/url"
    14  	"os"
    15  	"os/signal"
    16  	"strings"
    17  	"syscall"
    18  	"testing"
    19  	"time"
    20  )
    21  
// Connection settings shared by the integration tests. init() fills them
// from the SNOWFLAKE_TEST_* environment variables. createDSN combines them
// into dsn. setup() replaces schemaname with a per-run schema.
var (
	username         string
	pass             string
	account          string
	dbname           string
	schemaname       string
	warehouse        string
	rolename         string
	dsn              string
	host             string
	port             string
	protocol         string
	customPrivateKey bool            // Whether user has specified the private key path
	testPrivKey      *rsa.PrivateKey // Valid private key used for all test cases
)

// SQL templates and constants shared by tests in this package.
const (
	// selectNumberSQL casts a literal to NUMBER(precision, scale). Its format
	// arguments are the value, the precision and the scale.
	selectNumberSQL       = "SELECT %s::NUMBER(%v, %v) AS C"
	// selectVariousTypes returns one row with NUMBER, string, DOUBLE, BINARY
	// and BOOLEAN columns named C1-C6.
	selectVariousTypes    = "SELECT 1.0::NUMBER(30,2) as C1, 2::NUMBER(38,0) AS C2, 't3' AS C3, 4.2::DOUBLE AS C4, 'abcd'::BINARY AS C5, true AS C6"
	// selectRandomGenerator produces the requested number of rows, each with a
	// sequence value and a RANDSTR(1000, ...) string.
	selectRandomGenerator = "SELECT SEQ8(), RANDSTR(1000, RANDOM()) FROM TABLE(GENERATOR(ROWCOUNT=>%v))"
	// PSTLocation is the IANA zone the TIMESTAMP_LTZ tests use both as the
	// session timezone and for building expected values.
	PSTLocation           = "America/Los_Angeles"
)
    44  
    45  // The tests require the following parameters in the environment variables.
    46  // SNOWFLAKE_TEST_USER, SNOWFLAKE_TEST_PASSWORD, SNOWFLAKE_TEST_ACCOUNT,
    47  // SNOWFLAKE_TEST_DATABASE, SNOWFLAKE_TEST_SCHEMA, SNOWFLAKE_TEST_WAREHOUSE.
    48  // Optionally you may specify SNOWFLAKE_TEST_PROTOCOL, SNOWFLAKE_TEST_HOST
    49  // and SNOWFLAKE_TEST_PORT to specify the endpoint.
    50  func init() {
    51  	// get environment variables
    52  	env := func(key, defaultValue string) string {
    53  		if value := os.Getenv(key); value != "" {
    54  			return value
    55  		}
    56  		return defaultValue
    57  	}
    58  	username = env("SNOWFLAKE_TEST_USER", "testuser")
    59  	pass = env("SNOWFLAKE_TEST_PASSWORD", "testpassword")
    60  	account = env("SNOWFLAKE_TEST_ACCOUNT", "testaccount")
    61  	dbname = env("SNOWFLAKE_TEST_DATABASE", "testdb")
    62  	schemaname = env("SNOWFLAKE_TEST_SCHEMA", "public")
    63  	rolename = env("SNOWFLAKE_TEST_ROLE", "sysadmin")
    64  	warehouse = env("SNOWFLAKE_TEST_WAREHOUSE", "testwarehouse")
    65  
    66  	protocol = env("SNOWFLAKE_TEST_PROTOCOL", "https")
    67  	host = os.Getenv("SNOWFLAKE_TEST_HOST")
    68  	port = env("SNOWFLAKE_TEST_PORT", "443")
    69  	if host == "" {
    70  		host = fmt.Sprintf("%s.snowflakecomputing.com", account)
    71  	} else {
    72  		host = fmt.Sprintf("%s:%s", host, port)
    73  	}
    74  
    75  	setupPrivateKey()
    76  
    77  	createDSN("UTC")
    78  }
    79  
    80  func createDSN(timezone string) {
    81  	dsn = fmt.Sprintf("%s:%s@%s/%s/%s", username, pass, host, dbname, schemaname)
    82  
    83  	parameters := url.Values{}
    84  	parameters.Add("timezone", timezone)
    85  	if protocol != "" {
    86  		parameters.Add("protocol", protocol)
    87  	}
    88  	if account != "" {
    89  		parameters.Add("account", account)
    90  	}
    91  	if warehouse != "" {
    92  		parameters.Add("warehouse", warehouse)
    93  	}
    94  	if rolename != "" {
    95  		parameters.Add("role", rolename)
    96  	}
    97  
    98  	if len(parameters) > 0 {
    99  		dsn += "?" + parameters.Encode()
   100  	}
   101  }
   102  
   103  // setup creates a test schema so that all tests can run in the same schema
   104  func setup() (string, error) {
   105  	env := func(key, defaultValue string) string {
   106  		if value := os.Getenv(key); value != "" {
   107  			return value
   108  		}
   109  		return defaultValue
   110  	}
   111  
   112  	orgSchemaname := schemaname
   113  	if env("GITHUB_WORKFLOW", "") != "" {
   114  		githubRunnerID := env("RUNNER_TRACKING_ID", "GITHUB_RUNNER_ID")
   115  		githubRunnerID = strings.ReplaceAll(githubRunnerID, "-", "_")
   116  		githubSha := env("GITHUB_SHA", "GITHUB_SHA")
   117  		schemaname = fmt.Sprintf("%v_%v", githubRunnerID, githubSha)
   118  	} else {
   119  		schemaname = fmt.Sprintf("golang_%v", time.Now().UnixNano())
   120  	}
   121  	var db *sql.DB
   122  	var err error
   123  	if db, err = sql.Open("snowflake", dsn); err != nil {
   124  		return "", fmt.Errorf("failed to open db. err: %v", err)
   125  	}
   126  	defer db.Close()
   127  	if _, err = db.Exec(fmt.Sprintf("CREATE OR REPLACE SCHEMA %v", schemaname)); err != nil {
   128  		return "", fmt.Errorf("failed to create schema. %v", err)
   129  	}
   130  	createDSN("UTC")
   131  	return orgSchemaname, nil
   132  }
   133  
   134  // teardown drops the test schema
   135  func teardown() error {
   136  	var db *sql.DB
   137  	var err error
   138  	if db, err = sql.Open("snowflake", dsn); err != nil {
   139  		return fmt.Errorf("failed to open db. %v, err: %v", dsn, err)
   140  	}
   141  	defer db.Close()
   142  	if _, err = db.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS %v", schemaname)); err != nil {
   143  		return fmt.Errorf("failed to create schema. %v", err)
   144  	}
   145  	return nil
   146  }
   147  
   148  func TestMain(m *testing.M) {
   149  	flag.Parse()
   150  	signal.Ignore(syscall.SIGQUIT)
   151  	if value := os.Getenv("SKIP_SETUP"); value != "" {
   152  		os.Exit(m.Run())
   153  	}
   154  
   155  	if _, err := setup(); err != nil {
   156  		panic(err)
   157  	}
   158  	ret := m.Run()
   159  	if err := teardown(); err != nil {
   160  		panic(err)
   161  	}
   162  	os.Exit(ret)
   163  }
   164  
// DBTest pairs a *testing.T with a *sql.Conn so helper methods can run a
// statement and fail the test in one call. Because T is embedded, Fatalf,
// Errorf and the other T methods are available directly on a DBTest.
// runDBTest constructs one per test.
type DBTest struct {
	*testing.T
	conn *sql.Conn
}
   169  
   170  func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *RowsExtended) {
   171  	// handler interrupt signal
   172  	ctx, cancel := context.WithCancel(context.Background())
   173  	c := make(chan os.Signal, 1)
   174  	c0 := make(chan bool, 1)
   175  	signal.Notify(c, os.Interrupt)
   176  	defer func() {
   177  		signal.Stop(c)
   178  	}()
   179  	go func() {
   180  		select {
   181  		case <-c:
   182  			fmt.Println("Caught signal, canceling...")
   183  			cancel()
   184  		case <-ctx.Done():
   185  			fmt.Println("Done")
   186  		case <-c0:
   187  		}
   188  		close(c)
   189  	}()
   190  
   191  	rs, err := dbt.conn.QueryContext(ctx, query, args...)
   192  	if err != nil {
   193  		dbt.fail("query", query, err)
   194  	}
   195  	return &RowsExtended{
   196  		rows:      rs,
   197  		closeChan: &c0,
   198  	}
   199  }
   200  
   201  func (dbt *DBTest) mustQueryContext(ctx context.Context, query string, args ...interface{}) (rows *RowsExtended) {
   202  	// handler interrupt signal
   203  	ctx, cancel := context.WithCancel(ctx)
   204  	c := make(chan os.Signal, 1)
   205  	c0 := make(chan bool, 1)
   206  	signal.Notify(c, os.Interrupt)
   207  	defer func() {
   208  		signal.Stop(c)
   209  	}()
   210  	go func() {
   211  		select {
   212  		case <-c:
   213  			fmt.Println("Caught signal, canceling...")
   214  			cancel()
   215  		case <-ctx.Done():
   216  			fmt.Println("Done")
   217  		case <-c0:
   218  		}
   219  		close(c)
   220  	}()
   221  
   222  	rs, err := dbt.conn.QueryContext(ctx, query, args...)
   223  	if err != nil {
   224  		dbt.fail("query", query, err)
   225  	}
   226  	return &RowsExtended{
   227  		rows:      rs,
   228  		closeChan: &c0,
   229  	}
   230  }
   231  
   232  func (dbt *DBTest) query(query string, args ...any) (*sql.Rows, error) {
   233  	return dbt.conn.QueryContext(context.Background(), query, args...)
   234  }
   235  
   236  func (dbt *DBTest) mustQueryAssertCount(query string, expected int, args ...interface{}) {
   237  	rows := dbt.mustQuery(query, args...)
   238  	defer rows.Close()
   239  	cnt := 0
   240  	for rows.Next() {
   241  		cnt++
   242  	}
   243  	if cnt != expected {
   244  		dbt.Fatalf("expected %v, got %v", expected, cnt)
   245  	}
   246  }
   247  
   248  func (dbt *DBTest) prepare(query string) (*sql.Stmt, error) {
   249  	return dbt.conn.PrepareContext(context.Background(), query)
   250  }
   251  
   252  func (dbt *DBTest) fail(method, query string, err error) {
   253  	if len(query) > 300 {
   254  		query = "[query too large to print]"
   255  	}
   256  	dbt.Fatalf("error on %s [%s]: %s", method, query, err.Error())
   257  }
   258  
   259  func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
   260  	return dbt.mustExecContext(context.Background(), query, args...)
   261  }
   262  
   263  func (dbt *DBTest) mustExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result) {
   264  	res, err := dbt.conn.ExecContext(ctx, query, args...)
   265  	if err != nil {
   266  		dbt.fail("exec context", query, err)
   267  	}
   268  	return res
   269  }
   270  
   271  func (dbt *DBTest) exec(query string, args ...any) (sql.Result, error) {
   272  	return dbt.conn.ExecContext(context.Background(), query, args...)
   273  }
   274  
   275  func (dbt *DBTest) mustDecimalSize(ct *sql.ColumnType) (pr int64, sc int64) {
   276  	var ok bool
   277  	pr, sc, ok = ct.DecimalSize()
   278  	if !ok {
   279  		dbt.Fatalf("failed to get decimal size. %v", ct)
   280  	}
   281  	return pr, sc
   282  }
   283  
   284  func (dbt *DBTest) mustFailDecimalSize(ct *sql.ColumnType) {
   285  	var ok bool
   286  	if _, _, ok = ct.DecimalSize(); ok {
   287  		dbt.Fatalf("should not return decimal size. %v", ct)
   288  	}
   289  }
   290  
   291  func (dbt *DBTest) mustLength(ct *sql.ColumnType) (cLen int64) {
   292  	var ok bool
   293  	cLen, ok = ct.Length()
   294  	if !ok {
   295  		dbt.Fatalf("failed to get length. %v", ct)
   296  	}
   297  	return cLen
   298  }
   299  
   300  func (dbt *DBTest) mustFailLength(ct *sql.ColumnType) {
   301  	var ok bool
   302  	if _, ok = ct.Length(); ok {
   303  		dbt.Fatalf("should not return length. %v", ct)
   304  	}
   305  }
   306  
   307  func (dbt *DBTest) mustNullable(ct *sql.ColumnType) (canNull bool) {
   308  	var ok bool
   309  	canNull, ok = ct.Nullable()
   310  	if !ok {
   311  		dbt.Fatalf("failed to get length. %v", ct)
   312  	}
   313  	return canNull
   314  }
   315  
   316  func (dbt *DBTest) mustPrepare(query string) (stmt *sql.Stmt) {
   317  	stmt, err := dbt.conn.PrepareContext(context.Background(), query)
   318  	if err != nil {
   319  		dbt.fail("prepare", query, err)
   320  	}
   321  	return stmt
   322  }
   323  
// SCTest pairs a *testing.T with a *snowflakeConn so tests can drive the
// driver-level connection directly, using driver.Value/driver.NamedValue
// arguments. runSnowflakeConnTest builds it around an authenticated
// connection.
type SCTest struct {
	*testing.T
	sc *snowflakeConn
}
   328  
   329  func (sct *SCTest) fail(method, query string, err error) {
   330  	if len(query) > 300 {
   331  		query = "[query too large to print]"
   332  	}
   333  	sct.Fatalf("error on %s [%s]: %s", method, query, err.Error())
   334  }
   335  
   336  func (sct *SCTest) mustExec(query string, args []driver.Value) driver.Result {
   337  	result, err := sct.sc.Exec(query, args)
   338  	if err != nil {
   339  		sct.fail("exec", query, err)
   340  	}
   341  	return result
   342  }
   343  func (sct *SCTest) mustQuery(query string, args []driver.Value) driver.Rows {
   344  	rows, err := sct.sc.Query(query, args)
   345  	if err != nil {
   346  		sct.fail("query", query, err)
   347  	}
   348  	return rows
   349  }
   350  
   351  func (sct *SCTest) mustQueryContext(ctx context.Context, query string, args []driver.NamedValue) driver.Rows {
   352  	rows, err := sct.sc.QueryContext(ctx, query, args)
   353  	if err != nil {
   354  		sct.fail("QueryContext", query, err)
   355  	}
   356  	return rows
   357  }
   358  
   359  func (sct *SCTest) mustExecContext(ctx context.Context, query string, args []driver.NamedValue) driver.Result {
   360  	result, err := sct.sc.ExecContext(ctx, query, args)
   361  	if err != nil {
   362  		sct.fail("ExecContext", query, err)
   363  	}
   364  	return result
   365  }
   366  
   367  func runDBTest(t *testing.T, test func(dbt *DBTest)) {
   368  	conn := openConn(t)
   369  	defer conn.Close()
   370  	dbt := &DBTest{t, conn}
   371  
   372  	test(dbt)
   373  }
   374  
   375  func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) {
   376  	config, err := ParseDSN(dsn)
   377  	if err != nil {
   378  		t.Error(err)
   379  	}
   380  	sc, err := buildSnowflakeConn(context.Background(), *config)
   381  	if err != nil {
   382  		t.Fatal(err)
   383  	}
   384  	defer sc.Close()
   385  	if err = authenticateWithConfig(sc); err != nil {
   386  		t.Fatal(err)
   387  	}
   388  
   389  	sct := &SCTest{t, sc}
   390  
   391  	test(sct)
   392  }
   393  
   394  func runningOnAWS() bool {
   395  	return os.Getenv("CLOUD_PROVIDER") == "AWS"
   396  }
   397  
   398  func runningOnGCP() bool {
   399  	return os.Getenv("CLOUD_PROVIDER") == "GCP"
   400  }
   401  
   402  func TestBogusUserPasswordParameters(t *testing.T) {
   403  	invalidDNS := fmt.Sprintf("%s:%s@%s", "bogus", pass, host)
   404  	invalidUserPassErrorTests(invalidDNS, t)
   405  	invalidDNS = fmt.Sprintf("%s:%s@%s", username, "INVALID_PASSWORD", host)
   406  	invalidUserPassErrorTests(invalidDNS, t)
   407  }
   408  
   409  func invalidUserPassErrorTests(invalidDNS string, t *testing.T) {
   410  	parameters := url.Values{}
   411  	if protocol != "" {
   412  		parameters.Add("protocol", protocol)
   413  	}
   414  	if account != "" {
   415  		parameters.Add("account", account)
   416  	}
   417  	invalidDNS += "?" + parameters.Encode()
   418  	db, err := sql.Open("snowflake", invalidDNS)
   419  	if err != nil {
   420  		t.Fatalf("error creating a connection object: %s", err.Error())
   421  	}
   422  	// actual connection won't happen until run a query
   423  	defer db.Close()
   424  	if _, err = db.Exec("SELECT 1"); err == nil {
   425  		t.Fatal("should cause an error.")
   426  	}
   427  	if driverErr, ok := err.(*SnowflakeError); ok {
   428  		if driverErr.Number != 390100 {
   429  			t.Fatalf("wrong error code: %v", driverErr)
   430  		}
   431  		if !strings.Contains(driverErr.Error(), "390100") {
   432  			t.Fatalf("error message should included the error code. got: %v", driverErr.Error())
   433  		}
   434  	} else {
   435  		t.Fatalf("wrong error code: %v", err)
   436  	}
   437  }
   438  
   439  func TestBogusHostNameParameters(t *testing.T) {
   440  	invalidDNS := fmt.Sprintf("%s:%s@%s", username, pass, "INVALID_HOST:1234")
   441  	invalidHostErrorTests(invalidDNS, []string{"no such host", "verify account name is correct", "HTTP Status: 403", "Temporary failure in name resolution", "server misbehaving"}, t)
   442  	invalidDNS = fmt.Sprintf("%s:%s@%s", username, pass, "INVALID_HOST")
   443  	invalidHostErrorTests(invalidDNS, []string{"read: connection reset by peer", "EOF", "verify account name is correct", "HTTP Status: 403", "Temporary failure in name resolution", "server misbehaving"}, t)
   444  }
   445  
   446  func invalidHostErrorTests(invalidDNS string, mstr []string, t *testing.T) {
   447  	parameters := url.Values{}
   448  	if protocol != "" {
   449  		parameters.Add("protocol", protocol)
   450  	}
   451  	if account != "" {
   452  		parameters.Add("account", account)
   453  	}
   454  	parameters.Add("loginTimeout", "10")
   455  	invalidDNS += "?" + parameters.Encode()
   456  	db, err := sql.Open("snowflake", invalidDNS)
   457  	if err != nil {
   458  		t.Fatalf("error creating a connection object: %s", err.Error())
   459  	}
   460  	// actual connection won't happen until run a query
   461  	defer db.Close()
   462  	if _, err = db.Exec("SELECT 1"); err == nil {
   463  		t.Fatal("should cause an error.")
   464  	}
   465  	found := false
   466  	for _, m := range mstr {
   467  		if strings.Contains(err.Error(), m) {
   468  			found = true
   469  		}
   470  	}
   471  	if !found {
   472  		t.Fatalf("wrong error: %v", err)
   473  	}
   474  }
   475  
   476  func TestCommentOnlyQuery(t *testing.T) {
   477  	runDBTest(t, func(dbt *DBTest) {
   478  		query := "--"
   479  		// just a comment, no query
   480  		rows, err := dbt.query(query)
   481  		if err == nil {
   482  			rows.Close()
   483  			dbt.fail("query", query, err)
   484  		}
   485  		if driverErr, ok := err.(*SnowflakeError); ok {
   486  			if driverErr.Number != 900 { // syntax error
   487  				dbt.fail("query", query, err)
   488  			}
   489  		}
   490  	})
   491  }
   492  
   493  func TestEmptyQuery(t *testing.T) {
   494  	runDBTest(t, func(dbt *DBTest) {
   495  		query := "select 1 from dual where 1=0"
   496  		// just a comment, no query
   497  		rows := dbt.conn.QueryRowContext(context.Background(), query)
   498  		var v1 any
   499  		if err := rows.Scan(&v1); err != sql.ErrNoRows {
   500  			dbt.Errorf("should fail. err: %v", err)
   501  		}
   502  		rows = dbt.conn.QueryRowContext(context.Background(), query)
   503  		if err := rows.Scan(&v1); err != sql.ErrNoRows {
   504  			dbt.Errorf("should fail. err: %v", err)
   505  		}
   506  	})
   507  }
   508  
   509  func TestEmptyQueryWithRequestID(t *testing.T) {
   510  	runDBTest(t, func(dbt *DBTest) {
   511  		query := "select 1"
   512  		ctx := WithRequestID(context.Background(), NewUUID())
   513  		rows := dbt.conn.QueryRowContext(ctx, query)
   514  		var v1 interface{}
   515  		if err := rows.Scan(&v1); err != nil {
   516  			dbt.Errorf("should not have failed with valid request id. err: %v", err)
   517  		}
   518  	})
   519  }
   520  
   521  func TestCRUD(t *testing.T) {
   522  	runDBTest(t, func(dbt *DBTest) {
   523  		// Create Table
   524  		dbt.mustExec("CREATE OR REPLACE TABLE test (value BOOLEAN)")
   525  
   526  		// Test for unexpected Data
   527  		var out bool
   528  		rows := dbt.mustQuery("SELECT * FROM test")
   529  		defer rows.Close()
   530  		if rows.Next() {
   531  			dbt.Error("unexpected Data in empty table")
   532  		}
   533  
   534  		// Create Data
   535  		res := dbt.mustExec("INSERT INTO test VALUES (true)")
   536  		count, err := res.RowsAffected()
   537  		if err != nil {
   538  			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
   539  		}
   540  		if count != 1 {
   541  			dbt.Fatalf("expected 1 affected row, got %d", count)
   542  		}
   543  
   544  		id, err := res.LastInsertId()
   545  		if err != nil {
   546  			dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
   547  		}
   548  		if id != -1 {
   549  			dbt.Fatalf(
   550  				"expected InsertId -1, got %d. Snowflake doesn't support last insert ID", id)
   551  		}
   552  
   553  		// Read
   554  		rows = dbt.mustQuery("SELECT value FROM test")
   555  		defer rows.Close()
   556  		if rows.Next() {
   557  			rows.Scan(&out)
   558  			if !out {
   559  				dbt.Errorf("%t should be true", out)
   560  			}
   561  			if rows.Next() {
   562  				dbt.Error("unexpected Data")
   563  			}
   564  		} else {
   565  			dbt.Error("no Data")
   566  		}
   567  
   568  		// Update
   569  		res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true)
   570  		count, err = res.RowsAffected()
   571  		if err != nil {
   572  			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
   573  		}
   574  		if count != 1 {
   575  			dbt.Fatalf("expected 1 affected row, got %d", count)
   576  		}
   577  
   578  		// Check Update
   579  		rows = dbt.mustQuery("SELECT value FROM test")
   580  		defer rows.Close()
   581  		if rows.Next() {
   582  			rows.Scan(&out)
   583  			if out {
   584  				dbt.Errorf("%t should be true", out)
   585  			}
   586  			if rows.Next() {
   587  				dbt.Error("unexpected Data")
   588  			}
   589  		} else {
   590  			dbt.Error("no Data")
   591  		}
   592  
   593  		// Delete
   594  		res = dbt.mustExec("DELETE FROM test WHERE value = ?", false)
   595  		count, err = res.RowsAffected()
   596  		if err != nil {
   597  			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
   598  		}
   599  		if count != 1 {
   600  			dbt.Fatalf("expected 1 affected row, got %d", count)
   601  		}
   602  
   603  		// Check for unexpected rows
   604  		res = dbt.mustExec("DELETE FROM test")
   605  		count, err = res.RowsAffected()
   606  		if err != nil {
   607  			dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
   608  		}
   609  		if count != 0 {
   610  			dbt.Fatalf("expected 0 affected row, got %d", count)
   611  		}
   612  	})
   613  }
   614  
   615  func TestInt(t *testing.T) {
   616  	testInt(t, false)
   617  }
   618  
   619  func testInt(t *testing.T, json bool) {
   620  	runDBTest(t, func(dbt *DBTest) {
   621  		types := []string{"INT", "INTEGER"}
   622  		in := int64(42)
   623  		var out int64
   624  		var rows *RowsExtended
   625  
   626  		// SIGNED
   627  		for _, v := range types {
   628  			t.Run(v, func(t *testing.T) {
   629  				if json {
   630  					dbt.mustExec(forceJSON)
   631  				}
   632  				dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")")
   633  				dbt.mustExec("INSERT INTO test VALUES (?)", in)
   634  				rows = dbt.mustQuery("SELECT value FROM test")
   635  				defer rows.Close()
   636  				if rows.Next() {
   637  					rows.Scan(&out)
   638  					if in != out {
   639  						dbt.Errorf("%s: %d != %d", v, in, out)
   640  					}
   641  				} else {
   642  					dbt.Errorf("%s: no data", v)
   643  				}
   644  
   645  			})
   646  		}
   647  		dbt.mustExec("DROP TABLE IF EXISTS test")
   648  	})
   649  }
   650  
   651  func TestFloat32(t *testing.T) {
   652  	testFloat32(t, false)
   653  }
   654  
   655  func testFloat32(t *testing.T, json bool) {
   656  	runDBTest(t, func(dbt *DBTest) {
   657  		types := [2]string{"FLOAT", "DOUBLE"}
   658  		in := float32(42.23)
   659  		var out float32
   660  		var rows *RowsExtended
   661  		for _, v := range types {
   662  			t.Run(v, func(t *testing.T) {
   663  				if json {
   664  					dbt.mustExec(forceJSON)
   665  				}
   666  				dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")")
   667  				dbt.mustExec("INSERT INTO test VALUES (?)", in)
   668  				rows = dbt.mustQuery("SELECT value FROM test")
   669  				defer rows.Close()
   670  				if rows.Next() {
   671  					err := rows.Scan(&out)
   672  					if err != nil {
   673  						dbt.Errorf("failed to scan data: %v", err)
   674  					}
   675  					if in != out {
   676  						dbt.Errorf("%s: %g != %g", v, in, out)
   677  					}
   678  				} else {
   679  					dbt.Errorf("%s: no data", v)
   680  				}
   681  			})
   682  		}
   683  		dbt.mustExec("DROP TABLE IF EXISTS test")
   684  	})
   685  }
   686  
   687  func TestFloat64(t *testing.T) {
   688  	testFloat64(t, false)
   689  }
   690  
   691  func testFloat64(t *testing.T, json bool) {
   692  	runDBTest(t, func(dbt *DBTest) {
   693  		types := [2]string{"FLOAT", "DOUBLE"}
   694  		expected := 42.23
   695  		var out float64
   696  		var rows *RowsExtended
   697  		for _, v := range types {
   698  			t.Run(v, func(t *testing.T) {
   699  				if json {
   700  					dbt.mustExec(forceJSON)
   701  				}
   702  				dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")")
   703  				dbt.mustExec("INSERT INTO test VALUES (42.23)")
   704  				rows = dbt.mustQuery("SELECT value FROM test")
   705  				defer rows.Close()
   706  				if rows.Next() {
   707  					rows.Scan(&out)
   708  					if expected != out {
   709  						dbt.Errorf("%s: %g != %g", v, expected, out)
   710  					}
   711  				} else {
   712  					dbt.Errorf("%s: no data", v)
   713  				}
   714  			})
   715  		}
   716  		dbt.mustExec("DROP TABLE IF EXISTS test")
   717  	})
   718  }
   719  
   720  func TestString(t *testing.T) {
   721  	testString(t, false)
   722  }
   723  
   724  func testString(t *testing.T, json bool) {
   725  	runDBTest(t, func(dbt *DBTest) {
   726  		if json {
   727  			dbt.mustExec(forceJSON)
   728  		}
   729  		types := []string{"CHAR(255)", "VARCHAR(255)", "TEXT", "STRING"}
   730  		in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах  น่าฟังเอย"
   731  		var out string
   732  		var rows *RowsExtended
   733  
   734  		for _, v := range types {
   735  			t.Run(v, func(t *testing.T) {
   736  				dbt.mustExec("CREATE OR REPLACE TABLE test (value " + v + ")")
   737  				dbt.mustExec("INSERT INTO test VALUES (?)", in)
   738  
   739  				rows = dbt.mustQuery("SELECT value FROM test")
   740  				defer rows.Close()
   741  				if rows.Next() {
   742  					rows.Scan(&out)
   743  					if in != out {
   744  						dbt.Errorf("%s: %s != %s", v, in, out)
   745  					}
   746  				} else {
   747  					dbt.Errorf("%s: no data", v)
   748  				}
   749  			})
   750  		}
   751  		dbt.mustExec("DROP TABLE IF EXISTS test")
   752  
   753  		// BLOB (Snowflake doesn't support BLOB type but STRING covers large text data)
   754  		dbt.mustExec("CREATE OR REPLACE TABLE test (id int, value STRING)")
   755  
   756  		id := 2
   757  		in = `Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam
   758  			nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam
   759  			erat, sed diam voluptua. At vero eos et accusam et justo duo
   760  			dolores et ea rebum. Stet clita kasd gubergren, no sea takimata
   761  			sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet,
   762  			consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt
   763  			ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero
   764  			eos et accusam et justo duo dolores et ea rebum. Stet clita kasd
   765  			gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.`
   766  		dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in)
   767  
   768  		if err := dbt.conn.QueryRowContext(context.Background(), "SELECT value FROM test WHERE id = ?", id).Scan(&out); err != nil {
   769  			dbt.Fatalf("Error on BLOB-Query: %s", err.Error())
   770  		} else if out != in {
   771  			dbt.Errorf("BLOB: %s != %s", in, out)
   772  		}
   773  	})
   774  }
   775  
// tcDateTimeTimestamp groups timeTest cases that share one Snowflake column
// type and one Go time layout.
type tcDateTimeTimestamp struct {
	// dbtype is the SQL type the source string is cast to, e.g. "TIME(6)".
	dbtype string
	// tlayout is the Go time layout used to fill in missing source strings
	// and to format fetched times in failure messages.
	tlayout string
	// tests are the individual cases run against dbtype.
	tests []timeTest
}
   781  
// timeTest is one fetch case. s is cast to the column type under test, and
// the fetched value is compared with t. If the driver returns raw bytes
// instead of a time.Time, the comparison is with s itself.
type timeTest struct {
	s string    // source date time string
	t time.Time // expected fetched data
}
   786  
   787  func (tt timeTest) genQuery() string {
   788  	return "SELECT '%s'::%s"
   789  }
   790  
   791  func (tt timeTest) run(t *testing.T, dbt *DBTest, dbtype, tlayout string) {
   792  	var rows *RowsExtended
   793  	query := fmt.Sprintf(tt.genQuery(), tt.s, dbtype)
   794  	rows = dbt.mustQuery(query)
   795  	defer rows.Close()
   796  	var err error
   797  	if !rows.Next() {
   798  		err = rows.Err()
   799  		if err == nil {
   800  			err = fmt.Errorf("no data")
   801  		}
   802  		dbt.Errorf("%s: %s", dbtype, err)
   803  		return
   804  	}
   805  
   806  	var dst interface{}
   807  	if err = rows.Scan(&dst); err != nil {
   808  		dbt.Errorf("%s: %s", dbtype, err)
   809  		return
   810  	}
   811  	switch val := dst.(type) {
   812  	case []uint8:
   813  		str := string(val)
   814  		if str == tt.s {
   815  			return
   816  		}
   817  		dbt.Errorf("%s to string: expected %q, got %q",
   818  			dbtype,
   819  			tt.s,
   820  			str,
   821  		)
   822  	case time.Time:
   823  		if val.UnixNano() == tt.t.UnixNano() {
   824  			return
   825  		}
   826  		t.Logf("source:%v, expected: %v, got:%v", tt.s, tt.t, val)
   827  		dbt.Errorf("%s to string: expected %q, got %q",
   828  			dbtype,
   829  			tt.s,
   830  			val.Format(tlayout),
   831  		)
   832  	default:
   833  		dbt.Errorf("%s: unhandled type %T (is '%v')",
   834  			dbtype, val, val,
   835  		)
   836  	}
   837  }
   838  
   839  func TestSimpleDateTimeTimestampFetch(t *testing.T) {
   840  	testSimpleDateTimeTimestampFetch(t, false)
   841  }
   842  
   843  func testSimpleDateTimeTimestampFetch(t *testing.T, json bool) {
   844  	var scan = func(rows *RowsExtended, cd interface{}, ct interface{}, cts interface{}) {
   845  		if err := rows.Scan(cd, ct, cts); err != nil {
   846  			t.Fatal(err)
   847  		}
   848  	}
   849  	var fetchTypes = []func(*RowsExtended){
   850  		func(rows *RowsExtended) {
   851  			var cd, ct, cts time.Time
   852  			scan(rows, &cd, &ct, &cts)
   853  		},
   854  		func(rows *RowsExtended) {
   855  			var cd, ct, cts time.Time
   856  			scan(rows, &cd, &ct, &cts)
   857  		},
   858  	}
   859  	runDBTest(t, func(dbt *DBTest) {
   860  		if json {
   861  			dbt.mustExec(forceJSON)
   862  		}
   863  		for _, f := range fetchTypes {
   864  			rows := dbt.mustQuery("SELECT CURRENT_DATE(), CURRENT_TIME(), CURRENT_TIMESTAMP()")
   865  			defer rows.Close()
   866  			if rows.Next() {
   867  				f(rows)
   868  			} else {
   869  				t.Fatal("no results")
   870  			}
   871  		}
   872  	})
   873  }
   874  
   875  func TestDateTime(t *testing.T) {
   876  	testDateTime(t, false)
   877  }
   878  
   879  func testDateTime(t *testing.T, json bool) {
   880  	afterTime := func(t time.Time, d string) time.Time {
   881  		dur, err := time.ParseDuration(d)
   882  		if err != nil {
   883  			panic(err)
   884  		}
   885  		return t.Add(dur)
   886  	}
   887  	t0 := time.Time{}
   888  	tstr0 := "0000-00-00 00:00:00.000000000"
   889  	testcases := []tcDateTimeTimestamp{
   890  		{"DATE", format[:10], []timeTest{
   891  			{t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)},
   892  			{t: time.Date(2, 8, 2, 0, 0, 0, 0, time.UTC), s: "0002-08-02"},
   893  		}},
   894  		{"TIME", format[11:19], []timeTest{
   895  			{t: afterTime(t0, "12345s")},
   896  			{t: t0, s: tstr0[11:19]},
   897  		}},
   898  		{"TIME(0)", format[11:19], []timeTest{
   899  			{t: afterTime(t0, "12345s")},
   900  			{t: t0, s: tstr0[11:19]},
   901  		}},
   902  		{"TIME(1)", format[11:21], []timeTest{
   903  			{t: afterTime(t0, "12345600ms")},
   904  			{t: t0, s: tstr0[11:21]},
   905  		}},
   906  		{"TIME(6)", format[11:], []timeTest{
   907  			{t: t0, s: tstr0[11:]},
   908  		}},
   909  		{"DATETIME", format[:19], []timeTest{
   910  			{t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
   911  		}},
   912  		{"DATETIME(0)", format[:21], []timeTest{
   913  			{t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)},
   914  		}},
   915  		{"DATETIME(1)", format[:21], []timeTest{
   916  			{t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)},
   917  		}},
   918  		{"DATETIME(6)", format, []timeTest{
   919  			{t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)},
   920  		}},
   921  		{"DATETIME(9)", format, []timeTest{
   922  			{t: time.Date(2011, 11, 20, 21, 27, 37, 123456789, time.UTC)},
   923  		}},
   924  	}
   925  	runDBTest(t, func(dbt *DBTest) {
   926  		if json {
   927  			dbt.mustExec(forceJSON)
   928  		}
   929  		for _, setups := range testcases {
   930  			t.Run(setups.dbtype, func(t *testing.T) {
   931  				for _, setup := range setups.tests {
   932  					if setup.s == "" {
   933  						// fill time string wherever Go can reliable produce it
   934  						setup.s = setup.t.Format(setups.tlayout)
   935  					}
   936  					setup.run(t, dbt, setups.dbtype, setups.tlayout)
   937  				}
   938  			})
   939  		}
   940  	})
   941  }
   942  
   943  func TestTimestampLTZ(t *testing.T) {
   944  	testTimestampLTZ(t, false)
   945  }
   946  
   947  func testTimestampLTZ(t *testing.T, json bool) {
   948  	// Set session time zone in Los Angeles, same as machine
   949  	createDSN(PSTLocation)
   950  	location, err := time.LoadLocation(PSTLocation)
   951  	if err != nil {
   952  		t.Error(err)
   953  	}
   954  	testcases := []tcDateTimeTimestamp{
   955  		{
   956  			dbtype:  "TIMESTAMP_LTZ(9)",
   957  			tlayout: format,
   958  			tests: []timeTest{
   959  				{
   960  					s: "2016-12-30 05:02:03",
   961  					t: time.Date(2016, 12, 30, 5, 2, 3, 0, location),
   962  				},
   963  				{
   964  					s: "2016-12-30 05:02:03 -00:00",
   965  					t: time.Date(2016, 12, 30, 5, 2, 3, 0, time.UTC),
   966  				},
   967  				{
   968  					s: "2017-05-12 00:51:42",
   969  					t: time.Date(2017, 5, 12, 0, 51, 42, 0, location),
   970  				},
   971  				{
   972  					s: "2017-03-12 01:00:00",
   973  					t: time.Date(2017, 3, 12, 1, 0, 0, 0, location),
   974  				},
   975  				{
   976  					s: "2017-03-13 04:00:00",
   977  					t: time.Date(2017, 3, 13, 4, 0, 0, 0, location),
   978  				},
   979  				{
   980  					s: "2017-03-13 04:00:00.123456789",
   981  					t: time.Date(2017, 3, 13, 4, 0, 0, 123456789, location),
   982  				},
   983  			},
   984  		},
   985  		{
   986  			dbtype:  "TIMESTAMP_LTZ(8)",
   987  			tlayout: format,
   988  			tests: []timeTest{
   989  				{
   990  					s: "2017-03-13 04:00:00.123456789",
   991  					t: time.Date(2017, 3, 13, 4, 0, 0, 123456780, location),
   992  				},
   993  			},
   994  		},
   995  	}
   996  	runDBTest(t, func(dbt *DBTest) {
   997  		if json {
   998  			dbt.mustExec(forceJSON)
   999  		}
  1000  		for _, setups := range testcases {
  1001  			t.Run(setups.dbtype, func(t *testing.T) {
  1002  				for _, setup := range setups.tests {
  1003  					if setup.s == "" {
  1004  						// fill time string wherever Go can reliable produce it
  1005  						setup.s = setup.t.Format(setups.tlayout)
  1006  					}
  1007  					setup.run(t, dbt, setups.dbtype, setups.tlayout)
  1008  				}
  1009  			})
  1010  		}
  1011  	})
  1012  	// Revert timezone to UTC, which is default for the test suit
  1013  	createDSN("UTC")
  1014  }
  1015  
  1016  func TestTimestampTZ(t *testing.T) {
  1017  	testTimestampTZ(t, false)
  1018  }
  1019  
  1020  func testTimestampTZ(t *testing.T, json bool) {
  1021  	sflo := func(offsets string) (loc *time.Location) {
  1022  		r, err := LocationWithOffsetString(offsets)
  1023  		if err != nil {
  1024  			return time.UTC
  1025  		}
  1026  		return r
  1027  	}
  1028  	testcases := []tcDateTimeTimestamp{
  1029  		{
  1030  			dbtype:  "TIMESTAMP_TZ(9)",
  1031  			tlayout: format,
  1032  			tests: []timeTest{
  1033  				{
  1034  					s: "2016-12-30 05:02:03 +07:00",
  1035  					t: time.Date(2016, 12, 30, 5, 2, 3, 0,
  1036  						sflo("+0700")),
  1037  				},
  1038  				{
  1039  					s: "2017-05-23 03:56:41 -09:00",
  1040  					t: time.Date(2017, 5, 23, 3, 56, 41, 0,
  1041  						sflo("-0900")),
  1042  				},
  1043  			},
  1044  		},
  1045  	}
  1046  	runDBTest(t, func(dbt *DBTest) {
  1047  		if json {
  1048  			dbt.mustExec(forceJSON)
  1049  		}
  1050  		for _, setups := range testcases {
  1051  			t.Run(setups.dbtype, func(t *testing.T) {
  1052  				for _, setup := range setups.tests {
  1053  					if setup.s == "" {
  1054  						// fill time string wherever Go can reliable produce it
  1055  						setup.s = setup.t.Format(setups.tlayout)
  1056  					}
  1057  					setup.run(t, dbt, setups.dbtype, setups.tlayout)
  1058  				}
  1059  			})
  1060  		}
  1061  	})
  1062  }
  1063  
  1064  func TestNULL(t *testing.T) {
  1065  	testNULL(t, false)
  1066  }
  1067  
  1068  func testNULL(t *testing.T, json bool) {
  1069  	runDBTest(t, func(dbt *DBTest) {
  1070  		if json {
  1071  			dbt.mustExec(forceJSON)
  1072  		}
  1073  		nullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT NULL")
  1074  		if err != nil {
  1075  			dbt.Fatal(err)
  1076  		}
  1077  		defer nullStmt.Close()
  1078  
  1079  		nonNullStmt, err := dbt.conn.PrepareContext(context.Background(), "SELECT 1")
  1080  		if err != nil {
  1081  			dbt.Fatal(err)
  1082  		}
  1083  		defer nonNullStmt.Close()
  1084  
  1085  		// NullBool
  1086  		var nb sql.NullBool
  1087  		// Invalid
  1088  		if err = nullStmt.QueryRow().Scan(&nb); err != nil {
  1089  			dbt.Fatal(err)
  1090  		}
  1091  		if nb.Valid {
  1092  			dbt.Error("valid NullBool which should be invalid")
  1093  		}
  1094  		// Valid
  1095  		if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
  1096  			dbt.Fatal(err)
  1097  		}
  1098  		if !nb.Valid {
  1099  			dbt.Error("invalid NullBool which should be valid")
  1100  		} else if !nb.Bool {
  1101  			dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
  1102  		}
  1103  
  1104  		// NullFloat64
  1105  		var nf sql.NullFloat64
  1106  		// Invalid
  1107  		if err = nullStmt.QueryRow().Scan(&nf); err != nil {
  1108  			dbt.Fatal(err)
  1109  		}
  1110  		if nf.Valid {
  1111  			dbt.Error("valid NullFloat64 which should be invalid")
  1112  		}
  1113  		// Valid
  1114  		if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
  1115  			dbt.Fatal(err)
  1116  		}
  1117  		if !nf.Valid {
  1118  			dbt.Error("invalid NullFloat64 which should be valid")
  1119  		} else if nf.Float64 != float64(1) {
  1120  			dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
  1121  		}
  1122  
  1123  		// NullInt64
  1124  		var ni sql.NullInt64
  1125  		// Invalid
  1126  		if err = nullStmt.QueryRow().Scan(&ni); err != nil {
  1127  			dbt.Fatal(err)
  1128  		}
  1129  		if ni.Valid {
  1130  			dbt.Error("valid NullInt64 which should be invalid")
  1131  		}
  1132  		// Valid
  1133  		if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
  1134  			dbt.Fatal(err)
  1135  		}
  1136  		if !ni.Valid {
  1137  			dbt.Error("invalid NullInt64 which should be valid")
  1138  		} else if ni.Int64 != int64(1) {
  1139  			dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64)
  1140  		}
  1141  
  1142  		// NullString
  1143  		var ns sql.NullString
  1144  		// Invalid
  1145  		if err = nullStmt.QueryRow().Scan(&ns); err != nil {
  1146  			dbt.Fatal(err)
  1147  		}
  1148  		if ns.Valid {
  1149  			dbt.Error("valid NullString which should be invalid")
  1150  		}
  1151  		// Valid
  1152  		if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
  1153  			dbt.Fatal(err)
  1154  		}
  1155  		if !ns.Valid {
  1156  			dbt.Error("invalid NullString which should be valid")
  1157  		} else if ns.String != `1` {
  1158  			dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)")
  1159  		}
  1160  
  1161  		// nil-bytes
  1162  		var b []byte
  1163  		// Read nil
  1164  		if err = nullStmt.QueryRow().Scan(&b); err != nil {
  1165  			dbt.Fatal(err)
  1166  		}
  1167  		if b != nil {
  1168  			dbt.Error("non-nil []byte which should be nil")
  1169  		}
  1170  		// Read non-nil
  1171  		if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
  1172  			dbt.Fatal(err)
  1173  		}
  1174  		if b == nil {
  1175  			dbt.Error("nil []byte which should be non-nil")
  1176  		}
  1177  		// Insert nil
  1178  		b = nil
  1179  		success := false
  1180  		if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ? IS NULL", b).Scan(&success); err != nil {
  1181  			dbt.Fatal(err)
  1182  		}
  1183  		if !success {
  1184  			dbt.Error("inserting []byte(nil) as NULL failed")
  1185  			t.Fatal("stopping")
  1186  		}
  1187  		// Check input==output with input==nil
  1188  		b = nil
  1189  		if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil {
  1190  			dbt.Fatal(err)
  1191  		}
  1192  		if b != nil {
  1193  			dbt.Error("non-nil echo from nil input")
  1194  		}
  1195  		// Check input==output with input!=nil
  1196  		b = []byte("")
  1197  		if err = dbt.conn.QueryRowContext(context.Background(), "SELECT ?", b).Scan(&b); err != nil {
  1198  			dbt.Fatal(err)
  1199  		}
  1200  		if b == nil {
  1201  			dbt.Error("nil echo from non-nil input")
  1202  		}
  1203  
  1204  		// Insert NULL
  1205  		dbt.mustExec("CREATE OR REPLACE TABLE test (dummmy1 int, value int, dummy2 int)")
  1206  		dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2)
  1207  
  1208  		var out interface{}
  1209  		rows := dbt.mustQuery("SELECT * FROM test")
  1210  		defer rows.Close()
  1211  		if rows.Next() {
  1212  			rows.Scan(&out)
  1213  			if out != nil {
  1214  				dbt.Errorf("%v != nil", out)
  1215  			}
  1216  		} else {
  1217  			dbt.Error("no data")
  1218  		}
  1219  	})
  1220  }
  1221  
  1222  func TestVariant(t *testing.T) {
  1223  	testVariant(t, false)
  1224  }
  1225  
  1226  func testVariant(t *testing.T, json bool) {
  1227  	runDBTest(t, func(dbt *DBTest) {
  1228  		if json {
  1229  			dbt.mustExec(forceJSON)
  1230  		}
  1231  		rows := dbt.mustQuery(`select parse_json('[{"id":1, "name":"test1"},{"id":2, "name":"test2"}]')`)
  1232  		defer rows.Close()
  1233  		var v string
  1234  		if rows.Next() {
  1235  			if err := rows.Scan(&v); err != nil {
  1236  				t.Fatal(err)
  1237  			}
  1238  		} else {
  1239  			t.Fatal("no rows")
  1240  		}
  1241  	})
  1242  }
  1243  
  1244  func TestArray(t *testing.T) {
  1245  	testArray(t, false)
  1246  }
  1247  
  1248  func testArray(t *testing.T, json bool) {
  1249  	runDBTest(t, func(dbt *DBTest) {
  1250  		if json {
  1251  			dbt.mustExec(forceJSON)
  1252  		}
  1253  		rows := dbt.mustQuery(`select as_array(parse_json('[{"id":1, "name":"test1"},{"id":2, "name":"test2"}]'))`)
  1254  		defer rows.Close()
  1255  		var v string
  1256  		if rows.Next() {
  1257  			if err := rows.Scan(&v); err != nil {
  1258  				t.Fatal(err)
  1259  			}
  1260  		} else {
  1261  			t.Fatal("no rows")
  1262  		}
  1263  	})
  1264  }
  1265  
  1266  func TestLargeSetResult(t *testing.T) {
  1267  	CustomJSONDecoderEnabled = false
  1268  	testLargeSetResult(t, 100000, false)
  1269  }
  1270  
  1271  func testLargeSetResult(t *testing.T, numrows int, json bool) {
  1272  	runDBTest(t, func(dbt *DBTest) {
  1273  		if json {
  1274  			dbt.mustExec(forceJSON)
  1275  		}
  1276  		rows := dbt.mustQuery(fmt.Sprintf(selectRandomGenerator, numrows))
  1277  		defer rows.Close()
  1278  		cnt := 0
  1279  		var idx int
  1280  		var v string
  1281  		for rows.Next() {
  1282  			if err := rows.Scan(&idx, &v); err != nil {
  1283  				t.Fatal(err)
  1284  			}
  1285  			cnt++
  1286  		}
  1287  		logger.Infof("NextResultSet: %v", rows.NextResultSet())
  1288  
  1289  		if cnt != numrows {
  1290  			dbt.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
  1291  		}
  1292  	})
  1293  }
  1294  
  1295  func TestPingpongQuery(t *testing.T) {
  1296  	runDBTest(t, func(dbt *DBTest) {
  1297  		numrows := 1
  1298  		rows := dbt.mustQuery("SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 60))")
  1299  		defer rows.Close()
  1300  		cnt := 0
  1301  		for rows.Next() {
  1302  			cnt++
  1303  		}
  1304  		if cnt != numrows {
  1305  			dbt.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
  1306  		}
  1307  	})
  1308  }
  1309  
  1310  func TestDML(t *testing.T) {
  1311  	runDBTest(t, func(dbt *DBTest) {
  1312  		dbt.mustExec("CREATE OR REPLACE TABLE test(c1 int, c2 string)")
  1313  		if err := insertData(dbt, false); err != nil {
  1314  			dbt.Fatalf("failed to insert data: %v", err)
  1315  		}
  1316  		results, err := queryTest(dbt)
  1317  		if err != nil {
  1318  			dbt.Fatalf("failed to query test table: %v", err)
  1319  		}
  1320  		if len(*results) != 0 {
  1321  			dbt.Fatalf("number of returned data didn't match. expected 0, got: %v", len(*results))
  1322  		}
  1323  		if err = insertData(dbt, true); err != nil {
  1324  			dbt.Fatalf("failed to insert data: %v", err)
  1325  		}
  1326  		results, err = queryTest(dbt)
  1327  		if err != nil {
  1328  			dbt.Fatalf("failed to query test table: %v", err)
  1329  		}
  1330  		if len(*results) != 2 {
  1331  			dbt.Fatalf("number of returned data didn't match. expected 2, got: %v", len(*results))
  1332  		}
  1333  	})
  1334  }
  1335  
  1336  func insertData(dbt *DBTest, commit bool) error {
  1337  	tx, err := dbt.conn.BeginTx(context.Background(), nil)
  1338  	if err != nil {
  1339  		dbt.Fatalf("failed to begin transaction: %v", err)
  1340  	}
  1341  	res, err := tx.Exec("INSERT INTO test VALUES(1, 'test1'), (2, 'test2')")
  1342  	if err != nil {
  1343  		dbt.Fatalf("failed to insert value into test: %v", err)
  1344  	}
  1345  	n, err := res.RowsAffected()
  1346  	if err != nil {
  1347  		dbt.Fatalf("failed to rows affected: %v", err)
  1348  	}
  1349  	if n != 2 {
  1350  		dbt.Fatalf("failed to insert value into test. expected: 2, got: %v", n)
  1351  	}
  1352  	results, err := queryTestTx(tx)
  1353  	if err != nil {
  1354  		dbt.Fatalf("failed to query test table: %v", err)
  1355  	}
  1356  	if len(*results) != 2 {
  1357  		dbt.Fatalf("number of returned data didn't match. expected 2, got: %v", len(*results))
  1358  	}
  1359  	if commit {
  1360  		if err = tx.Commit(); err != nil {
  1361  			return err
  1362  		}
  1363  	} else {
  1364  		if err = tx.Rollback(); err != nil {
  1365  			return err
  1366  		}
  1367  	}
  1368  	return err
  1369  }
  1370  
  1371  func queryTestTx(tx *sql.Tx) (*map[int]string, error) {
  1372  	var c1 int
  1373  	var c2 string
  1374  	rows, err := tx.Query("SELECT c1, c2 FROM test")
  1375  	if err != nil {
  1376  		return nil, err
  1377  	}
  1378  	defer rows.Close()
  1379  
  1380  	results := make(map[int]string, 2)
  1381  	for rows.Next() {
  1382  		if err = rows.Scan(&c1, &c2); err != nil {
  1383  			return nil, err
  1384  		}
  1385  		results[c1] = c2
  1386  	}
  1387  	return &results, nil
  1388  }
  1389  
  1390  func queryTest(dbt *DBTest) (*map[int]string, error) {
  1391  	var c1 int
  1392  	var c2 string
  1393  	rows, err := dbt.query("SELECT c1, c2 FROM test")
  1394  	if err != nil {
  1395  		return nil, err
  1396  	}
  1397  	defer rows.Close()
  1398  	results := make(map[int]string, 2)
  1399  	for rows.Next() {
  1400  		if err = rows.Scan(&c1, &c2); err != nil {
  1401  			return nil, err
  1402  		}
  1403  		results[c1] = c2
  1404  	}
  1405  	return &results, nil
  1406  }
  1407  
  1408  func TestCancelQuery(t *testing.T) {
  1409  	runDBTest(t, func(dbt *DBTest) {
  1410  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
  1411  		defer cancel()
  1412  
  1413  		_, err := dbt.conn.QueryContext(ctx, "SELECT DISTINCT 1 FROM TABLE(GENERATOR(TIMELIMIT=> 100))")
  1414  		if err == nil {
  1415  			dbt.Fatal("No timeout error returned")
  1416  		}
  1417  		if err.Error() != "context deadline exceeded" {
  1418  			dbt.Fatalf("Timeout error mismatch: expect %v, receive %v", context.DeadlineExceeded, err.Error())
  1419  		}
  1420  	})
  1421  }
  1422  
  1423  func TestPing(t *testing.T) {
  1424  	db := openConn(t)
  1425  	if err := db.PingContext(context.Background()); err != nil {
  1426  		t.Fatalf("failed to ping. err: %v", err)
  1427  	}
  1428  	if err := db.PingContext(context.Background()); err != nil {
  1429  		t.Fatalf("failed to ping with context. err: %v", err)
  1430  	}
  1431  	if err := db.Close(); err != nil {
  1432  		t.Fatalf("failed to close db. err: %v", err)
  1433  	}
  1434  	if err := db.PingContext(context.Background()); err == nil {
  1435  		t.Fatal("should have failed to ping")
  1436  	}
  1437  	if err := db.PingContext(context.Background()); err == nil {
  1438  		t.Fatal("should have failed to ping with context")
  1439  	}
  1440  }
  1441  
  1442  func TestDoubleDollar(t *testing.T) {
  1443  	// no escape is required for dollar signs
  1444  	runDBTest(t, func(dbt *DBTest) {
  1445  		sql := `create or replace function dateErr(I double) returns date
  1446  language javascript strict
  1447  as $$
  1448    var x = [
  1449      0, "1400000000000",
  1450      "2013-04-05",
  1451      [], [1400000000000],
  1452      "x1234",
  1453      Number.NaN, null, undefined,
  1454      {},
  1455      [1400000000000,1500000000000]
  1456    ];
  1457    return x[I];
  1458  $$
  1459  ;`
  1460  		dbt.mustExec(sql)
  1461  	})
  1462  }
  1463  
  1464  func TestTimezoneSessionParameter(t *testing.T) {
  1465  	createDSN(PSTLocation)
  1466  	conn := openConn(t)
  1467  	defer conn.Close()
  1468  
  1469  	rows, err := conn.QueryContext(context.Background(), "SHOW PARAMETERS LIKE 'TIMEZONE'")
  1470  	if err != nil {
  1471  		t.Errorf("failed to run show parameters. err: %v", err)
  1472  	}
  1473  	defer rows.Close()
  1474  	if !rows.Next() {
  1475  		t.Fatal("failed to get timezone.")
  1476  	}
  1477  
  1478  	p, err := ScanSnowflakeParameter(rows)
  1479  	if err != nil {
  1480  		t.Errorf("failed to run get timezone value. err: %v", err)
  1481  	}
  1482  	if p.Value != PSTLocation {
  1483  		t.Errorf("failed to get an expected timezone. got: %v", p.Value)
  1484  	}
  1485  	createDSN("UTC")
  1486  }
  1487  
  1488  func TestLargeSetResultCancel(t *testing.T) {
  1489  	runDBTest(t, func(dbt *DBTest) {
  1490  		c := make(chan error)
  1491  		ctx, cancel := context.WithCancel(context.Background())
  1492  		go func() {
  1493  			// attempt to run a 100 seconds query, but it should be canceled in 1 second
  1494  			timelimit := 100
  1495  			rows, err := dbt.conn.QueryContext(
  1496  				ctx,
  1497  				fmt.Sprintf("SELECT COUNT(*) FROM TABLE(GENERATOR(timelimit=>%v))", timelimit))
  1498  			if err != nil {
  1499  				c <- err
  1500  				return
  1501  			}
  1502  			defer rows.Close()
  1503  			c <- nil
  1504  		}()
  1505  		// cancel after 1 second
  1506  		time.Sleep(time.Second)
  1507  		cancel()
  1508  		ret := <-c
  1509  		if ret.Error() != "context canceled" {
  1510  			t.Fatalf("failed to cancel. err: %v", ret)
  1511  		}
  1512  		close(c)
  1513  	})
  1514  }
  1515  
  1516  func TestValidateDatabaseParameter(t *testing.T) {
  1517  	baseDSN := fmt.Sprintf("%s:%s@%s", username, pass, host)
  1518  	testcases := []struct {
  1519  		dsn       string
  1520  		params    map[string]string
  1521  		errorCode int
  1522  	}{
  1523  		{
  1524  			dsn:       baseDSN + fmt.Sprintf("/%s/%s", "NOT_EXISTS", "NOT_EXISTS"),
  1525  			errorCode: ErrObjectNotExistOrAuthorized,
  1526  		},
  1527  		{
  1528  			dsn:       baseDSN + fmt.Sprintf("/%s/%s", dbname, "NOT_EXISTS"),
  1529  			errorCode: ErrObjectNotExistOrAuthorized,
  1530  		},
  1531  		{
  1532  			dsn: baseDSN + fmt.Sprintf("/%s/%s", dbname, schemaname),
  1533  			params: map[string]string{
  1534  				"warehouse": "NOT_EXIST",
  1535  			},
  1536  			errorCode: ErrObjectNotExistOrAuthorized,
  1537  		},
  1538  		{
  1539  			dsn: baseDSN + fmt.Sprintf("/%s/%s", dbname, schemaname),
  1540  			params: map[string]string{
  1541  				"role": "NOT_EXIST",
  1542  			},
  1543  			errorCode: ErrRoleNotExist,
  1544  		},
  1545  	}
  1546  	for idx, tc := range testcases {
  1547  		t.Run(dsn, func(t *testing.T) {
  1548  			newDSN := tc.dsn
  1549  			parameters := url.Values{}
  1550  			if protocol != "" {
  1551  				parameters.Add("protocol", protocol)
  1552  			}
  1553  			if account != "" {
  1554  				parameters.Add("account", account)
  1555  			}
  1556  			for k, v := range tc.params {
  1557  				parameters.Add(k, v)
  1558  			}
  1559  			newDSN += "?" + parameters.Encode()
  1560  			db, err := sql.Open("snowflake", newDSN)
  1561  			// actual connection won't happen until run a query
  1562  			if err != nil {
  1563  				t.Fatalf("error creating a connection object: %s", err.Error())
  1564  			}
  1565  			defer db.Close()
  1566  			if _, err = db.Exec("SELECT 1"); err == nil {
  1567  				t.Fatal("should cause an error.")
  1568  			}
  1569  			if driverErr, ok := err.(*SnowflakeError); ok {
  1570  				if driverErr.Number != tc.errorCode { // not exist error
  1571  					t.Errorf("got unexpected error: %v in %v", err, idx)
  1572  				}
  1573  			}
  1574  		})
  1575  	}
  1576  }
  1577  
  1578  func TestSpecifyWarehouseDatabase(t *testing.T) {
  1579  	dsn := fmt.Sprintf("%s:%s@%s/%s", username, pass, host, dbname)
  1580  	parameters := url.Values{}
  1581  	parameters.Add("account", account)
  1582  	parameters.Add("warehouse", warehouse)
  1583  	// parameters.Add("role", "nopublic") TODO: create nopublic role for test
  1584  	if protocol != "" {
  1585  		parameters.Add("protocol", protocol)
  1586  	}
  1587  	db, err := sql.Open("snowflake", dsn+"?"+parameters.Encode())
  1588  	if err != nil {
  1589  		t.Fatalf("error creating a connection object: %s", err.Error())
  1590  	}
  1591  	defer db.Close()
  1592  	if _, err = db.Exec("SELECT 1"); err != nil {
  1593  		t.Fatalf("failed to execute a select 1: %v", err)
  1594  	}
  1595  }
  1596  
  1597  func TestFetchNil(t *testing.T) {
  1598  	runDBTest(t, func(dbt *DBTest) {
  1599  		rows := dbt.mustQuery("SELECT * FROM values(3,4),(null, 5) order by 2")
  1600  		defer rows.Close()
  1601  		var c1 sql.NullInt64
  1602  		var c2 sql.NullInt64
  1603  
  1604  		var results []sql.NullInt64
  1605  		for rows.Next() {
  1606  			if err := rows.Scan(&c1, &c2); err != nil {
  1607  				dbt.Fatal(err)
  1608  			}
  1609  			results = append(results, c1)
  1610  		}
  1611  		if results[1].Valid {
  1612  			t.Errorf("First element of second row must be nil (NULL). %v", results)
  1613  		}
  1614  	})
  1615  }
  1616  
  1617  func TestPingInvalidHost(t *testing.T) {
  1618  	config := Config{
  1619  		Account:      "NOT_EXISTS",
  1620  		User:         "BOGUS_USER",
  1621  		Password:     "barbar",
  1622  		LoginTimeout: 10 * time.Second,
  1623  	}
  1624  
  1625  	testURL, err := DSN(&config)
  1626  	if err != nil {
  1627  		t.Fatalf("failed to parse config. config: %v, err: %v", config, err)
  1628  	}
  1629  
  1630  	db, err := sql.Open("snowflake", testURL)
  1631  	if err != nil {
  1632  		t.Fatalf("failed to initalize the connetion. err: %v", err)
  1633  	}
  1634  	ctx := context.Background()
  1635  	if err = db.PingContext(ctx); err == nil {
  1636  		t.Fatal("should cause an error")
  1637  	}
  1638  	if driverErr, ok := err.(*SnowflakeError); !ok || ok && driverErr.Number != ErrCodeFailedToConnect {
  1639  		// Failed to connect error
  1640  		t.Fatalf("error didn't match")
  1641  	}
  1642  }
  1643  
  1644  func TestOpenWithConfig(t *testing.T) {
  1645  	config, err := ParseDSN(dsn)
  1646  	if err != nil {
  1647  		t.Fatalf("failed to parse dsn. err: %v", err)
  1648  	}
  1649  	driver := SnowflakeDriver{}
  1650  	db, err := driver.OpenWithConfig(context.Background(), *config)
  1651  	if err != nil {
  1652  		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
  1653  	}
  1654  	db.Close()
  1655  }
  1656  
  1657  func TestOpenWithInvalidConfig(t *testing.T) {
  1658  	config, err := ParseDSN("u:p@h?tmpDirPath=%2Fnon-existing")
  1659  	if err != nil {
  1660  		t.Fatalf("failed to parse dsn. err: %v", err)
  1661  	}
  1662  	driver := SnowflakeDriver{}
  1663  	_, err = driver.OpenWithConfig(context.Background(), *config)
  1664  	if err == nil || !strings.Contains(err.Error(), "/non-existing") {
  1665  		t.Fatalf("should fail on missing directory")
  1666  	}
  1667  }
  1668  
  1669  type CountingTransport struct {
  1670  	requests int
  1671  }
  1672  
  1673  func (t *CountingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
  1674  	t.requests++
  1675  	return snowflakeInsecureTransport.RoundTrip(r)
  1676  }
  1677  
  1678  func TestOpenWithTransport(t *testing.T) {
  1679  	config, err := ParseDSN(dsn)
  1680  	if err != nil {
  1681  		t.Fatalf("failed to parse dsn. err: %v", err)
  1682  	}
  1683  	countingTransport := CountingTransport{}
  1684  	var transport http.RoundTripper = &countingTransport
  1685  	config.Transporter = transport
  1686  	driver := SnowflakeDriver{}
  1687  	db, err := driver.OpenWithConfig(context.Background(), *config)
  1688  	if err != nil {
  1689  		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
  1690  	}
  1691  	conn := db.(*snowflakeConn)
  1692  	if conn.rest.Client.Transport != transport {
  1693  		t.Fatal("transport doesn't match")
  1694  	}
  1695  	db.Close()
  1696  	if countingTransport.requests == 0 {
  1697  		t.Fatal("transport did not receive any requests")
  1698  	}
  1699  
  1700  	// Test that transport override also works in insecure mode
  1701  	countingTransport.requests = 0
  1702  	config.InsecureMode = true
  1703  	db, err = driver.OpenWithConfig(context.Background(), *config)
  1704  	if err != nil {
  1705  		t.Fatalf("failed to open with config. config: %v, err: %v", config, err)
  1706  	}
  1707  	conn = db.(*snowflakeConn)
  1708  	if conn.rest.Client.Transport != transport {
  1709  		t.Fatal("transport doesn't match")
  1710  	}
  1711  	db.Close()
  1712  	if countingTransport.requests == 0 {
  1713  		t.Fatal("transport did not receive any requests")
  1714  	}
  1715  }
  1716  
  1717  func createDSNWithClientSessionKeepAlive() {
  1718  	dsn = fmt.Sprintf("%s:%s@%s/%s/%s", username, pass, host, dbname, schemaname)
  1719  
  1720  	parameters := url.Values{}
  1721  	parameters.Add("client_session_keep_alive", "true")
  1722  	if protocol != "" {
  1723  		parameters.Add("protocol", protocol)
  1724  	}
  1725  	if account != "" {
  1726  		parameters.Add("account", account)
  1727  	}
  1728  	if warehouse != "" {
  1729  		parameters.Add("warehouse", warehouse)
  1730  	}
  1731  	if rolename != "" {
  1732  		parameters.Add("role", rolename)
  1733  	}
  1734  	if len(parameters) > 0 {
  1735  		dsn += "?" + parameters.Encode()
  1736  	}
  1737  }
  1738  
  1739  func TestClientSessionKeepAliveParameter(t *testing.T) {
  1740  	// This test doesn't really validate the CLIENT_SESSION_KEEP_ALIVE functionality but simply checks
  1741  	// the session parameter.
  1742  	createDSNWithClientSessionKeepAlive()
  1743  	runDBTest(t, func(dbt *DBTest) {
  1744  		rows := dbt.mustQuery("SHOW PARAMETERS LIKE 'CLIENT_SESSION_KEEP_ALIVE'")
  1745  		defer rows.Close()
  1746  		if !rows.Next() {
  1747  			t.Fatal("failed to get timezone.")
  1748  		}
  1749  
  1750  		p, err := ScanSnowflakeParameter(rows.rows)
  1751  		if err != nil {
  1752  			t.Errorf("failed to run get client_session_keep_alive value. err: %v", err)
  1753  		}
  1754  		if p.Value != "true" {
  1755  			t.Fatalf("failed to get an expected client_session_keep_alive. got: %v", p.Value)
  1756  		}
  1757  
  1758  		rows2 := dbt.mustQuery("select count(*) from table(generator(timelimit=>30))")
  1759  		defer rows2.Close()
  1760  	})
  1761  }
  1762  
  1763  func TestTimePrecision(t *testing.T) {
  1764  	runDBTest(t, func(dbt *DBTest) {
  1765  		dbt.mustExec("create or replace table z3 (t1 time(5))")
  1766  		rows := dbt.mustQuery("select * from z3")
  1767  		defer rows.Close()
  1768  		cols, err := rows.ColumnTypes()
  1769  		if err != nil {
  1770  			t.Error(err)
  1771  		}
  1772  		if pres, _, ok := cols[0].DecimalSize(); pres != 5 || !ok {
  1773  			t.Fatalf("Wrong value returned. Got %v instead of 5.", pres)
  1774  		}
  1775  	})
  1776  }