github.com/snowflakedb/gosnowflake@v1.9.0/multistatement_test.go (about)

     1  // Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"encoding/json"
     8  	"errors"
     9  	"io"
    10  	"net/http"
    11  	"net/url"
    12  	"os"
    13  	"reflect"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  func TestMultiStatementExecuteNoResultSet(t *testing.T) {
    19  	ctx, _ := WithMultiStatement(context.Background(), 4)
    20  	multiStmtQuery := "begin;\n" +
    21  		"delete from test_multi_statement_txn;\n" +
    22  		"insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" +
    23  		"commit;"
    24  
    25  	runDBTest(t, func(dbt *DBTest) {
    26  		dbt.mustExec(`create or replace table test_multi_statement_txn(c1 number, c2 string) as select 10, 'z'`)
    27  
    28  		res := dbt.mustExecContext(ctx, multiStmtQuery)
    29  		count, err := res.RowsAffected()
    30  		if err != nil {
    31  			t.Fatalf("res.RowsAffected() returned error: %v", err)
    32  		}
    33  		if count != 3 {
    34  			t.Fatalf("expected 3 affected rows, got %d", count)
    35  		}
    36  	})
    37  }
    38  
    39  func TestMultiStatementQueryResultSet(t *testing.T) {
    40  	ctx, _ := WithMultiStatement(context.Background(), 4)
    41  	multiStmtQuery := "select 123;\n" +
    42  		"select 456;\n" +
    43  		"select 789;\n" +
    44  		"select '000';"
    45  
    46  	var v1, v2, v3 int64
    47  	var v4 string
    48  
    49  	runDBTest(t, func(dbt *DBTest) {
    50  		rows := dbt.mustQueryContext(ctx, multiStmtQuery)
    51  		defer rows.Close()
    52  
    53  		// first statement
    54  		if rows.Next() {
    55  			if err := rows.Scan(&v1); err != nil {
    56  				t.Errorf("failed to scan: %#v", err)
    57  			}
    58  			if v1 != 123 {
    59  				t.Fatalf("failed to fetch. value: %v", v1)
    60  			}
    61  		} else {
    62  			t.Error("failed to query")
    63  		}
    64  
    65  		// second statement
    66  		if !rows.NextResultSet() {
    67  			t.Error("failed to retrieve next result set")
    68  		}
    69  		if rows.Next() {
    70  			if err := rows.Scan(&v2); err != nil {
    71  				t.Errorf("failed to scan: %#v", err)
    72  			}
    73  			if v2 != 456 {
    74  				t.Fatalf("failed to fetch. value: %v", v2)
    75  			}
    76  		} else {
    77  			t.Error("failed to query")
    78  		}
    79  
    80  		// third statement
    81  		if !rows.NextResultSet() {
    82  			t.Error("failed to retrieve next result set")
    83  		}
    84  		if rows.Next() {
    85  			if err := rows.Scan(&v3); err != nil {
    86  				t.Errorf("failed to scan: %#v", err)
    87  			}
    88  			if v3 != 789 {
    89  				t.Fatalf("failed to fetch. value: %v", v3)
    90  			}
    91  		} else {
    92  			t.Error("failed to query")
    93  		}
    94  
    95  		// fourth statement
    96  		if !rows.NextResultSet() {
    97  			t.Error("failed to retrieve next result set")
    98  		}
    99  		if rows.Next() {
   100  			if err := rows.Scan(&v4); err != nil {
   101  				t.Errorf("failed to scan: %#v", err)
   102  			}
   103  			if v4 != "000" {
   104  				t.Fatalf("failed to fetch. value: %v", v4)
   105  			}
   106  		} else {
   107  			t.Error("failed to query")
   108  		}
   109  	})
   110  }
   111  
   112  func TestMultiStatementExecuteResultSet(t *testing.T) {
   113  	ctx, _ := WithMultiStatement(context.Background(), 6)
   114  	multiStmtQuery := "begin;\n" +
   115  		"delete from test_multi_statement_txn_rb;\n" +
   116  		"insert into test_multi_statement_txn_rb values (1, 'a'), (2, 'b');\n" +
   117  		"select 1;\n" +
   118  		"select 2;\n" +
   119  		"rollback;"
   120  
   121  	runDBTest(t, func(dbt *DBTest) {
   122  		dbt.mustExec("drop table if exists test_multi_statement_txn_rb")
   123  		dbt.mustExec(`create or replace table test_multi_statement_txn_rb(
   124  			c1 number, c2 string) as select 10, 'z'`)
   125  		defer dbt.mustExec("drop table if exists test_multi_statement_txn_rb")
   126  
   127  		res := dbt.mustExecContext(ctx, multiStmtQuery)
   128  		count, err := res.RowsAffected()
   129  		if err != nil {
   130  			t.Fatalf("res.RowsAffected() returned error: %v", err)
   131  		}
   132  		if count != 3 {
   133  			t.Fatalf("expected 3 affected rows, got %d", count)
   134  		}
   135  	})
   136  }
   137  
   138  func TestMultiStatementQueryNoResultSet(t *testing.T) {
   139  	ctx, _ := WithMultiStatement(context.Background(), 4)
   140  	multiStmtQuery := "begin;\n" +
   141  		"delete from test_multi_statement_txn;\n" +
   142  		"insert into test_multi_statement_txn values (1, 'a'), (2, 'b');\n" +
   143  		"commit;"
   144  
   145  	runDBTest(t, func(dbt *DBTest) {
   146  		dbt.mustExec("drop table if exists test_multi_statement_txn")
   147  		dbt.mustExec(`create or replace table test_multi_statement_txn(
   148  			c1 number, c2 string) as select 10, 'z'`)
   149  		defer dbt.mustExec("drop table if exists tfmuest_multi_statement_txn")
   150  
   151  		rows := dbt.mustQueryContext(ctx, multiStmtQuery)
   152  		defer rows.Close()
   153  	})
   154  }
   155  
   156  func TestMultiStatementExecuteMix(t *testing.T) {
   157  	ctx, _ := WithMultiStatement(context.Background(), 3)
   158  	multiStmtQuery := "create or replace temporary table test_multi (cola int);\n" +
   159  		"insert into test_multi values (1), (2);\n" +
   160  		"select cola from test_multi order by cola asc;"
   161  
   162  	runDBTest(t, func(dbt *DBTest) {
   163  		dbt.mustExec("drop table if exists test_multi_statement_txn")
   164  		dbt.mustExec(`create or replace table test_multi_statement_txn(
   165  			c1 number, c2 string) as select 10, 'z'`)
   166  		defer dbt.mustExec("drop table if exists test_multi_statement_txn")
   167  
   168  		res := dbt.mustExecContext(ctx, multiStmtQuery)
   169  		count, err := res.RowsAffected()
   170  		if err != nil {
   171  			t.Fatalf("res.RowsAffected() returned error: %v", err)
   172  		}
   173  		if count != 2 {
   174  			t.Fatalf("expected 2 affected rows, got %d", count)
   175  		}
   176  	})
   177  }
   178  
   179  func TestMultiStatementQueryMix(t *testing.T) {
   180  	ctx, _ := WithMultiStatement(context.Background(), 3)
   181  	multiStmtQuery := "create or replace temporary table test_multi (cola int);\n" +
   182  		"insert into test_multi values (1), (2);\n" +
   183  		"select cola from test_multi order by cola asc;"
   184  
   185  	var count, v int
   186  	runDBTest(t, func(dbt *DBTest) {
   187  		dbt.mustExec("drop table if exists test_multi_statement_txn")
   188  		dbt.mustExec(`create or replace table test_multi_statement_txn(
   189  			c1 number, c2 string) as select 10, 'z'`)
   190  		defer dbt.mustExec("drop table if exists test_multi_statement_txn")
   191  
   192  		rows := dbt.mustQueryContext(ctx, multiStmtQuery)
   193  		defer rows.Close()
   194  
   195  		// first statement
   196  		if !rows.Next() {
   197  			t.Error("failed to query")
   198  		}
   199  
   200  		// second statement
   201  		rows.NextResultSet()
   202  		if rows.Next() {
   203  			if err := rows.Scan(&count); err != nil {
   204  				t.Errorf("failed to scan: %#v", err)
   205  			}
   206  			if count != 2 {
   207  				t.Fatalf("expected 2 affected rows, got %d", count)
   208  			}
   209  		}
   210  
   211  		expected := 1
   212  		// third statement
   213  		rows.NextResultSet()
   214  		for rows.Next() {
   215  			if err := rows.Scan(&v); err != nil {
   216  				t.Errorf("failed to scan: %#v", err)
   217  			}
   218  			if v != expected {
   219  				t.Fatalf("failed to fetch. value: %v", v)
   220  			}
   221  			expected++
   222  		}
   223  	})
   224  }
   225  
   226  func TestMultiStatementCountZero(t *testing.T) {
   227  	ctx, _ := WithMultiStatement(context.Background(), 0)
   228  	var v1 int
   229  	var v2 string
   230  	var v3 float64
   231  	var v4 bool
   232  
   233  	runDBTest(t, func(dbt *DBTest) {
   234  		// first query
   235  		multiStmtQuery1 := "select 123;\n" +
   236  			"select '456';"
   237  		rows1 := dbt.mustQueryContext(ctx, multiStmtQuery1)
   238  		defer rows1.Close()
   239  		// first statement
   240  		if rows1.Next() {
   241  			if err := rows1.Scan(&v1); err != nil {
   242  				t.Errorf("failed to scan: %#v", err)
   243  			}
   244  			if v1 != 123 {
   245  				t.Fatalf("failed to fetch. value: %v", v1)
   246  			}
   247  		} else {
   248  			t.Error("failed to query")
   249  		}
   250  
   251  		// second statement
   252  		if !rows1.NextResultSet() {
   253  			t.Error("failed to retrieve next result set")
   254  		}
   255  		if rows1.Next() {
   256  			if err := rows1.Scan(&v2); err != nil {
   257  				t.Errorf("failed to scan: %#v", err)
   258  			}
   259  			if v2 != "456" {
   260  				t.Fatalf("failed to fetch. value: %v", v2)
   261  			}
   262  		} else {
   263  			t.Error("failed to query")
   264  		}
   265  
   266  		// second query
   267  		multiStmtQuery2 := "select 789;\n" +
   268  			"select 'foo';\n" +
   269  			"select 0.123;\n" +
   270  			"select true;"
   271  		rows2 := dbt.mustQueryContext(ctx, multiStmtQuery2)
   272  		defer rows2.Close()
   273  		// first statement
   274  		if rows2.Next() {
   275  			if err := rows2.Scan(&v1); err != nil {
   276  				t.Errorf("failed to scan: %#v", err)
   277  			}
   278  			if v1 != 789 {
   279  				t.Fatalf("failed to fetch. value: %v", v1)
   280  			}
   281  		} else {
   282  			t.Error("failed to query")
   283  		}
   284  
   285  		// second statement
   286  		if !rows2.NextResultSet() {
   287  			t.Error("failed to retrieve next result set")
   288  		}
   289  		if rows2.Next() {
   290  			if err := rows2.Scan(&v2); err != nil {
   291  				t.Errorf("failed to scan: %#v", err)
   292  			}
   293  			if v2 != "foo" {
   294  				t.Fatalf("failed to fetch. value: %v", v2)
   295  			}
   296  		} else {
   297  			t.Error("failed to query")
   298  		}
   299  
   300  		// third statement
   301  		if !rows2.NextResultSet() {
   302  			t.Error("failed to retrieve next result set")
   303  		}
   304  		if rows2.Next() {
   305  			if err := rows2.Scan(&v3); err != nil {
   306  				t.Errorf("failed to scan: %#v", err)
   307  			}
   308  			if v3 != 0.123 {
   309  				t.Fatalf("failed to fetch. value: %v", v3)
   310  			}
   311  		} else {
   312  			t.Error("failed to query")
   313  		}
   314  
   315  		// fourth statement
   316  		if !rows2.NextResultSet() {
   317  			t.Error("failed to retrieve next result set")
   318  		}
   319  		if rows2.Next() {
   320  			if err := rows2.Scan(&v4); err != nil {
   321  				t.Errorf("failed to scan: %#v", err)
   322  			}
   323  			if v4 != true {
   324  				t.Fatalf("failed to fetch. value: %v", v4)
   325  			}
   326  		} else {
   327  			t.Error("failed to query")
   328  		}
   329  	})
   330  }
   331  
   332  func TestMultiStatementCountMismatch(t *testing.T) {
   333  	conn := openConn(t)
   334  	defer conn.Close()
   335  
   336  	multiStmtQuery := "select 123;\n" +
   337  		"select 456;\n" +
   338  		"select 789;\n" +
   339  		"select '000';"
   340  
   341  	ctx, _ := WithMultiStatement(context.Background(), 3)
   342  	if _, err := conn.QueryContext(ctx, multiStmtQuery); err == nil {
   343  		t.Fatal("should have failed to query multiple statements")
   344  	}
   345  }
   346  
   347  func TestMultiStatementVaryingColumnCount(t *testing.T) {
   348  	multiStmtQuery := "select c1 from test_tbl;\n" +
   349  		"select c1,c2 from test_tbl;"
   350  	ctx, _ := WithMultiStatement(context.Background(), 0)
   351  
   352  	var v1, v2 int
   353  	runDBTest(t, func(dbt *DBTest) {
   354  		dbt.mustExec("create or replace table test_tbl(c1 int, c2 int)")
   355  		dbt.mustExec("insert into test_tbl values(1, 0)")
   356  		defer dbt.mustExec("drop table if exists test_tbl")
   357  
   358  		rows := dbt.mustQueryContext(ctx, multiStmtQuery)
   359  		defer rows.Close()
   360  
   361  		if rows.Next() {
   362  			if err := rows.Scan(&v1); err != nil {
   363  				t.Errorf("failed to scan: %#v", err)
   364  			}
   365  			if v1 != 1 {
   366  				t.Fatalf("failed to fetch. value: %v", v1)
   367  			}
   368  		} else {
   369  			t.Error("failed to query")
   370  		}
   371  
   372  		if !rows.NextResultSet() {
   373  			t.Error("failed to retrieve next result set")
   374  		}
   375  
   376  		if rows.Next() {
   377  			if err := rows.Scan(&v1, &v2); err != nil {
   378  				t.Errorf("failed to scan: %#v", err)
   379  			}
   380  			if v1 != 1 || v2 != 0 {
   381  				t.Fatalf("failed to fetch. value: %v, %v", v1, v2)
   382  			}
   383  		} else {
   384  			t.Error("failed to query")
   385  		}
   386  	})
   387  }
   388  
   389  // The total completion time should be similar to the duration of the query on Snowflake UI.
   390  func TestMultiStatementExecutePerformance(t *testing.T) {
   391  	ctx, _ := WithMultiStatement(context.Background(), 100)
   392  	runDBTest(t, func(dbt *DBTest) {
   393  		file, err := os.Open("test_data/multistatements.sql")
   394  		if err != nil {
   395  			t.Fatalf("failed opening file: %s", err)
   396  		}
   397  		defer file.Close()
   398  		statements, err := io.ReadAll(file)
   399  		if err != nil {
   400  			t.Fatalf("failed reading file: %s", err)
   401  		}
   402  
   403  		sql := string(statements)
   404  
   405  		start := time.Now()
   406  		res := dbt.mustExecContext(ctx, sql)
   407  		duration := time.Since(start)
   408  
   409  		count, err := res.RowsAffected()
   410  		if err != nil {
   411  			t.Fatalf("res.RowsAffected() returned error: %v", err)
   412  		}
   413  		if count != 0 {
   414  			t.Fatalf("expected 0 affected rows, got %d", count)
   415  		}
   416  		t.Logf("The total completion time was %v", duration)
   417  
   418  		file, err = os.Open("test_data/multistatements_drop.sql")
   419  		if err != nil {
   420  			t.Fatalf("failed opening file: %s", err)
   421  		}
   422  		defer file.Close()
   423  		statements, err = io.ReadAll(file)
   424  		if err != nil {
   425  			t.Fatalf("failed reading file: %s", err)
   426  		}
   427  		sql = string(statements)
   428  		dbt.mustExecContext(ctx, sql)
   429  	})
   430  }
   431  
   432  func TestUnitGetChildResults(t *testing.T) {
   433  	testcases := []struct {
   434  		ids   string
   435  		types string
   436  		out   []childResult
   437  	}{
   438  		{"", "", nil},
   439  		{"", "4096", nil},
   440  		{"01aa3265-0405-ab7c-0000-53b106343aba,02aa3265-0405-ab7c-0000-53b106343aba", "12544,12544", []childResult{
   441  			{"01aa3265-0405-ab7c-0000-53b106343aba", "12544"},
   442  			{"02aa3265-0405-ab7c-0000-53b106343aba", "12544"}}},
   443  		{"01aa3265-0405-ab7c-0000-53b106343aba,02aa3265-0405-ab7c-0000-53b106343aba,03aa3265-0405-ab7c-0000-53b106343aba", "25344,4096,12544", []childResult{
   444  			{"01aa3265-0405-ab7c-0000-53b106343aba", "25344"},
   445  			{"02aa3265-0405-ab7c-0000-53b106343aba", "4096"},
   446  			{"03aa3265-0405-ab7c-0000-53b106343aba", "12544"}}},
   447  	}
   448  	for _, test := range testcases {
   449  		t.Run(test.ids, func(t *testing.T) {
   450  			res := getChildResults(test.ids, test.types)
   451  			if !reflect.DeepEqual(res, test.out) {
   452  				t.Fatalf("Child result should be equal, expected %v, actual %v", res, test.out)
   453  			}
   454  		})
   455  	}
   456  }
   457  
   458  func funcGetQueryRespFail(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
   459  	return nil, errors.New("failed to get query response")
   460  }
   461  
   462  func funcGetQueryRespError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
   463  	dd := &execResponseData{}
   464  	er := &execResponse{
   465  		Data:    *dd,
   466  		Message: "query failed",
   467  		Code:    "261000",
   468  		Success: false,
   469  	}
   470  	ba, err := json.Marshal(er)
   471  	if err != nil {
   472  		panic(err)
   473  	}
   474  
   475  	return &http.Response{
   476  		StatusCode: http.StatusOK,
   477  		Body:       &fakeResponseBody{body: ba},
   478  	}, nil
   479  }
   480  
   481  func TestUnitHandleMultiExec(t *testing.T) {
   482  	runSnowflakeConnTest(t, func(sct *SCTest) {
   483  		data := execResponseData{
   484  			ResultIDs:   "",
   485  			ResultTypes: "",
   486  		}
   487  		_, err := sct.sc.handleMultiExec(context.Background(), data)
   488  		if err == nil {
   489  			t.Fatalf("should have failed")
   490  		}
   491  		driverErr, ok := err.(*SnowflakeError)
   492  		if !ok {
   493  			t.Fatalf("should be snowflake error. err: %v", err)
   494  		}
   495  		if driverErr.Number != ErrNoResultIDs {
   496  			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number)
   497  		}
   498  
   499  		data = execResponseData{
   500  			ResultIDs:   "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre",
   501  			ResultTypes: "12544,12544",
   502  		}
   503  		sct.sc.rest = &snowflakeRestful{
   504  			FuncGet:          funcGetQueryRespFail,
   505  			FuncCloseSession: closeSessionMock,
   506  			TokenAccessor:    getSimpleTokenAccessor(),
   507  		}
   508  		_, err = sct.sc.handleMultiExec(context.Background(), data)
   509  		if err == nil {
   510  			t.Fatalf("should have failed")
   511  		}
   512  
   513  		sct.sc.rest.FuncGet = funcGetQueryRespError
   514  		data.SQLState = "01112"
   515  		_, err = sct.sc.handleMultiExec(context.Background(), data)
   516  		if err == nil {
   517  			t.Fatalf("should have failed")
   518  		}
   519  		driverErr, ok = err.(*SnowflakeError)
   520  		if !ok {
   521  			t.Fatalf("should be snowflake error. err: %v", err)
   522  		}
   523  		if driverErr.Number != ErrFailedToPostQuery {
   524  			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number)
   525  		}
   526  	})
   527  }
   528  
   529  func TestUnitHandleMultiQuery(t *testing.T) {
   530  	runSnowflakeConnTest(t, func(sct *SCTest) {
   531  		data := execResponseData{
   532  			ResultIDs:   "",
   533  			ResultTypes: "",
   534  		}
   535  		rows := new(snowflakeRows)
   536  		err := sct.sc.handleMultiQuery(context.Background(), data, rows)
   537  		if err == nil {
   538  			t.Fatalf("should have failed")
   539  		}
   540  		driverErr, ok := err.(*SnowflakeError)
   541  		if !ok {
   542  			t.Fatalf("should be snowflake error. err: %v", err)
   543  		}
   544  		if driverErr.Number != ErrNoResultIDs {
   545  			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNoResultIDs, driverErr.Number)
   546  		}
   547  		data = execResponseData{
   548  			ResultIDs:   "1eFhmhe23242kmfd540GgGre,1eFhmhe23242kmfd540GgGre",
   549  			ResultTypes: "12544,12544",
   550  		}
   551  		sct.sc.rest = &snowflakeRestful{
   552  			FuncGet:          funcGetQueryRespFail,
   553  			FuncCloseSession: closeSessionMock,
   554  			TokenAccessor:    getSimpleTokenAccessor(),
   555  		}
   556  		err = sct.sc.handleMultiQuery(context.Background(), data, rows)
   557  		if err == nil {
   558  			t.Fatalf("should have failed")
   559  		}
   560  
   561  		sct.sc.rest.FuncGet = funcGetQueryRespError
   562  		data.SQLState = "01112"
   563  		err = sct.sc.handleMultiQuery(context.Background(), data, rows)
   564  		if err == nil {
   565  			t.Fatalf("should have failed")
   566  		}
   567  		driverErr, ok = err.(*SnowflakeError)
   568  		if !ok {
   569  			t.Fatalf("should be snowflake error. err: %v", err)
   570  		}
   571  		if driverErr.Number != ErrFailedToPostQuery {
   572  			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToPostQuery, driverErr.Number)
   573  		}
   574  	})
   575  }