github.com/snowflakedb/gosnowflake@v1.9.0/rows_test.go

// Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.

package gosnowflake

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"
	"net/http"
	"sync"
	"testing"
	"time"
)

type RowsExtended struct {
	rows      *sql.Rows
	closeChan *chan bool
}

func (rs *RowsExtended) Close() error {
	*rs.closeChan <- true
	close(*rs.closeChan)
	return rs.rows.Close()
}

func (rs *RowsExtended) ColumnTypes() ([]*sql.ColumnType, error) {
	return rs.rows.ColumnTypes()
}

func (rs *RowsExtended) Columns() ([]string, error) {
	return rs.rows.Columns()
}

func (rs *RowsExtended) Err() error {
	return rs.rows.Err()
}

func (rs *RowsExtended) Next() bool {
	return rs.rows.Next()
}

func (rs *RowsExtended) NextResultSet() bool {
	return rs.rows.NextResultSet()
}

func (rs *RowsExtended) Scan(dest ...interface{}) error {
	return rs.rows.Scan(dest...)
}

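// NOTE (illustrative, not part of the original file): RowsExtended wraps
// *sql.Rows so a test can observe when Close is called by reading from the
// channel behind closeChan. A minimal usage sketch, assuming a *sql.DB named
// db and a buffered channel so the send in Close does not block:
//
//	closed := make(chan bool, 1)
//	rows, _ := db.Query("SELECT 1")
//	rs := &RowsExtended{rows: rows, closeChan: &closed}
//	for rs.Next() {
//		var n int
//		_ = rs.Scan(&n)
//	}
//	_ = rs.Close()
//	<-closed // true is received once Close has run
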
// test variables
var (
	rowsInChunk = 123
)

// Special cases where rows are already closed
func TestRowsClose(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		rows, err := dbt.query("SELECT 1")
		if err != nil {
			dbt.Fatal(err)
		}
		if err = rows.Close(); err != nil {
			dbt.Fatal(err)
		}

		if rows.Next() {
			dbt.Fatal("unexpected row after rows.Close()")
		}
		if err = rows.Err(); err != nil {
			dbt.Fatal(err)
		}
	})
}

func TestResultNoRows(t *testing.T) {
	// DDL
	runDBTest(t, func(dbt *DBTest) {
		row, err := dbt.exec("CREATE OR REPLACE TABLE test(c1 int)")
		if err != nil {
			t.Fatalf("failed to execute DDL. err: %v", err)
		}
		if _, err = row.RowsAffected(); err == nil {
			t.Fatal("should have failed to get RowsAffected")
		}
		if _, err = row.LastInsertId(); err == nil {
			t.Fatal("should have failed to get LastInsertId")
		}
	})
}

func TestRowsWithoutChunkDownloader(t *testing.T) {
	sts1 := "1"
	sts2 := "Test1"
	var i int
	cc := make([][]*string, 0)
	for i = 0; i < 10; i++ {
		cc = append(cc, []*string{&sts1, &sts2})
	}
	rt := []execResponseRowType{
		{Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true},
		{Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false},
	}
	cm := []execResponseChunk{}
	rows := new(snowflakeRows)
	rows.sc = nil
	rows.ChunkDownloader = &snowflakeChunkDownloader{
		sc:                 nil,
		ctx:                context.Background(),
		Total:              int64(len(cc)),
		ChunkMetas:         cm,
		TotalRowIndex:      int64(-1),
		Qrmk:               "",
		FuncDownload:       nil,
		FuncDownloadHelper: nil,
		RowSet:             rowSetType{RowType: rt, JSON: cc},
		QueryResultFormat:  "json",
	}
	rows.ChunkDownloader.start()
	dest := make([]driver.Value, 2)
	for i = 0; i < len(cc); i++ {
		if err := rows.Next(dest); err != nil {
			t.Fatalf("failed to get value. err: %v", err)
		}
		if dest[0] != sts1 {
			t.Fatalf("failed to get value. expected: %v, got: %v", sts1, dest[0])
		}
		if dest[1] != sts2 {
			t.Fatalf("failed to get value. expected: %v, got: %v", sts2, dest[1])
		}
	}
	if err := rows.Next(dest); err != io.EOF {
		t.Fatalf("failed to finish getting data. err: %v", err)
	}
	logger.Infof("dest: %v", dest)
}

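// For reference (explanatory note, not part of the original file):
// snowflakeRows implements driver.Rows, so the tests in this file consume
// results at the driver level. Next fills a []driver.Value with one row's
// column values and returns io.EOF once every row has been delivered:
//
//	dest := make([]driver.Value, 2)
//	for {
//		if err := rows.Next(dest); err == io.EOF {
//			break
//		} else if err != nil {
//			t.Fatal(err)
//		}
//		// dest now holds one row
//	}
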
func downloadChunkTest(ctx context.Context, scd *snowflakeChunkDownloader, idx int) {
	d := make([][]*string, 0)
	for i := 0; i < rowsInChunk; i++ {
		v1 := fmt.Sprintf("%v", idx*1000+i)
		v2 := fmt.Sprintf("testchunk%v", idx*1000+i)
		d = append(d, []*string{&v1, &v2})
	}
	scd.ChunksMutex.Lock()
	scd.Chunks[idx] = make([]chunkRowType, len(d))
	populateJSONRowSet(scd.Chunks[idx], d)
	scd.DoneDownloadCond.Broadcast()
	scd.ChunksMutex.Unlock()
}

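// Explanatory note (not part of the original file): tests that set
// FuncDownload to a stub such as downloadChunkTest never touch the network.
// As the stub above shows, a replacement only has to populate scd.Chunks[idx]
// while holding scd.ChunksMutex and then Broadcast on scd.DoneDownloadCond so
// a reader blocked in rows.Next can pick up the freshly "downloaded" chunk.
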
func TestRowsWithChunkDownloader(t *testing.T) {
	numChunks := 12
	// temporarily reduce the number of chunk download workers
	backupMaxChunkDownloadWorkers := MaxChunkDownloadWorkers
	MaxChunkDownloadWorkers = 2
	logger.Info("START TESTS")
	var i int
	cc := make([][]*string, 0)
	for i = 0; i < 100; i++ {
		v1 := fmt.Sprintf("%v", i)
		v2 := fmt.Sprintf("Test%v", i)
		cc = append(cc, []*string{&v1, &v2})
	}
	rt := []execResponseRowType{
		{Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true},
		{Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false},
	}
	cm := make([]execResponseChunk, 0)
	for i = 0; i < numChunks; i++ {
		cm = append(cm, execResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk})
	}
	rows := new(snowflakeRows)
	rows.sc = nil
	rows.ChunkDownloader = &snowflakeChunkDownloader{
		sc:            nil,
		ctx:           context.Background(),
		Total:         int64(len(cc) + numChunks*rowsInChunk),
		ChunkMetas:    cm,
		TotalRowIndex: int64(-1),
		Qrmk:          "HAHAHA",
		FuncDownload:  downloadChunkTest,
		RowSet:        rowSetType{RowType: rt, JSON: cc},
	}
	rows.ChunkDownloader.start()
	cnt := 0
	dest := make([]driver.Value, 2)
	for {
		err := rows.Next(dest)
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatalf("failed to get value. err: %v", err)
		}
		cnt++
	}
	if cnt != len(cc)+numChunks*rowsInChunk {
		t.Fatalf("failed to get all results. expected: %v, got: %v", len(cc)+numChunks*rowsInChunk, cnt)
	}
	logger.Infof("dest: %v", dest)
	MaxChunkDownloadWorkers = backupMaxChunkDownloadWorkers
	logger.Info("END TESTS")
}

func downloadChunkTestError(ctx context.Context, scd *snowflakeChunkDownloader, idx int) {
	// fail to download the 6th and 10th chunks; the downloads are retried and
	// succeed before the error counter reaches its limit
	// NOTE: zero based index
	scd.ChunksMutex.Lock()
	defer scd.ChunksMutex.Unlock()
	if (idx == 6 || idx == 10) && scd.ChunksErrorCounter < maxChunkDownloaderErrorCounter {
		scd.ChunksError <- &chunkError{
			Index: idx,
			Error: fmt.Errorf(
				"dummy error. idx: %v, errCnt: %v", idx+1, scd.ChunksErrorCounter)}
		scd.DoneDownloadCond.Broadcast()
		return
	}
	d := make([][]*string, 0)
	for i := 0; i < rowsInChunk; i++ {
		v1 := fmt.Sprintf("%v", idx*1000+i)
		v2 := fmt.Sprintf("testchunk%v", idx*1000+i)
		d = append(d, []*string{&v1, &v2})
	}
	scd.Chunks[idx] = make([]chunkRowType, len(d))
	populateJSONRowSet(scd.Chunks[idx], d)
	scd.DoneDownloadCond.Broadcast()
}

func TestRowsWithChunkDownloaderError(t *testing.T) {
	numChunks := 12
	// temporarily reduce the number of chunk download workers
	backupMaxChunkDownloadWorkers := MaxChunkDownloadWorkers
	MaxChunkDownloadWorkers = 3
	logger.Info("START TESTS")
	var i int
	cc := make([][]*string, 0)
	for i = 0; i < 100; i++ {
		v1 := fmt.Sprintf("%v", i)
		v2 := fmt.Sprintf("Test%v", i)
		cc = append(cc, []*string{&v1, &v2})
	}
	rt := []execResponseRowType{
		{Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true},
		{Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false},
	}
	cm := make([]execResponseChunk, 0)
	for i = 0; i < numChunks; i++ {
		cm = append(cm, execResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk})
	}
	rows := new(snowflakeRows)
	rows.sc = nil
	rows.ChunkDownloader = &snowflakeChunkDownloader{
		sc:            nil,
		ctx:           context.Background(),
		Total:         int64(len(cc) + numChunks*rowsInChunk),
		ChunkMetas:    cm,
		TotalRowIndex: int64(-1),
		Qrmk:          "HOHOHO",
		FuncDownload:  downloadChunkTestError,
		RowSet:        rowSetType{RowType: rt, JSON: cc},
	}
	rows.ChunkDownloader.start()
	cnt := 0
	dest := make([]driver.Value, 2)
	for {
		err := rows.Next(dest)
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatalf("failed to get value. err: %v", err)
		}
		// fmt.Printf("data: %v\n", dest)
		cnt++
	}
	if cnt != len(cc)+numChunks*rowsInChunk {
		t.Fatalf("failed to get all results. expected: %v, got: %v", len(cc)+numChunks*rowsInChunk, cnt)
	}
	logger.Infof("dest: %v", dest)
	MaxChunkDownloadWorkers = backupMaxChunkDownloadWorkers
	logger.Info("END TESTS")
}

func downloadChunkTestErrorFail(ctx context.Context, scd *snowflakeChunkDownloader, idx int) {
	// fail to download the 6th chunk; the failure persists past the retry
	// limit, so the download ultimately fails
	// NOTE: zero based index
	scd.ChunksMutex.Lock()
	defer scd.ChunksMutex.Unlock()
	if idx == 6 && scd.ChunksErrorCounter <= maxChunkDownloaderErrorCounter {
		scd.ChunksError <- &chunkError{
			Index: idx,
			Error: fmt.Errorf(
				"dummy error. idx: %v, errCnt: %v", idx+1, scd.ChunksErrorCounter)}
		scd.DoneDownloadCond.Broadcast()
		return
	}
	d := make([][]*string, 0)
	for i := 0; i < rowsInChunk; i++ {
		v1 := fmt.Sprintf("%v", idx*1000+i)
		v2 := fmt.Sprintf("testchunk%v", idx*1000+i)
		d = append(d, []*string{&v1, &v2})
	}
	scd.Chunks[idx] = make([]chunkRowType, len(d))
	populateJSONRowSet(scd.Chunks[idx], d)
	scd.DoneDownloadCond.Broadcast()
}

func TestRowsWithChunkDownloaderErrorFail(t *testing.T) {
	numChunks := 12
	logger.Info("START TESTS")
	var i int
	cc := make([][]*string, 0)
	for i = 0; i < 100; i++ {
		v1 := fmt.Sprintf("%v", i)
		v2 := fmt.Sprintf("Test%v", i)
		cc = append(cc, []*string{&v1, &v2})
	}
	rt := []execResponseRowType{
		{Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true},
		{Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false},
	}
	cm := make([]execResponseChunk, 0)
	for i = 0; i < numChunks; i++ {
		cm = append(cm, execResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk})
	}
	rows := new(snowflakeRows)
	rows.sc = nil
	rows.ChunkDownloader = &snowflakeChunkDownloader{
		sc:            nil,
		ctx:           context.Background(),
		Total:         int64(len(cc) + numChunks*rowsInChunk),
		ChunkMetas:    cm,
		TotalRowIndex: int64(-1),
		Qrmk:          "HOHOHO",
		FuncDownload:  downloadChunkTestErrorFail,
		RowSet:        rowSetType{RowType: rt, JSON: cc},
	}
	rows.ChunkDownloader.start()
	cnt := 0
	dest := make([]driver.Value, 2)
	for {
		err := rows.Next(dest)
		if err == io.EOF {
			break
		}
		if err != nil {
			logger.Infof(
				"failure was expected; rows read before the failure. expected: %v, got: %v", 715, cnt)
			break
		}
		cnt++
	}
}

func getChunkTestInvalidResponseBody(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (
	*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

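// fakeResponseBody is defined elsewhere in this package's test files. A
// minimal stand-in (hypothetical sketch, not the actual implementation) only
// needs to satisfy io.ReadCloser:
//
//	type fakeBody struct {
//		body []byte
//		read bool
//	}
//
//	func (b *fakeBody) Read(p []byte) (int, error) {
//		if b.read {
//			return 0, io.EOF
//		}
//		b.read = true
//		return copy(p, b.body), nil
//	}
//
//	func (b *fakeBody) Close() error { return nil }
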
func TestDownloadChunkInvalidResponseBody(t *testing.T) {
	numChunks := 2
	cm := make([]execResponseChunk, 0)
	for i := 0; i < numChunks; i++ {
		cm = append(cm, execResponseChunk{URL: fmt.Sprintf(
			"dummyURL%v", i+1), RowCount: rowsInChunk})
	}
	scd := &snowflakeChunkDownloader{
		sc: &snowflakeConn{
			rest: &snowflakeRestful{RequestTimeout: defaultRequestTimeout},
		},
		ctx:                context.Background(),
		ChunkMetas:         cm,
		TotalRowIndex:      int64(-1),
		Qrmk:               "HOHOHO",
		FuncDownload:       downloadChunk,
		FuncDownloadHelper: downloadChunkHelper,
		FuncGet:            getChunkTestInvalidResponseBody,
	}
	scd.ChunksMutex = &sync.Mutex{}
	scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex)
	scd.Chunks = make(map[int][]chunkRowType)
	scd.ChunksError = make(chan *chunkError, 1)
	scd.FuncDownload(scd.ctx, scd, 1)
	select {
	case errc := <-scd.ChunksError:
		if errc.Index != 1 {
			t.Fatalf("the error should have been caused by chunk idx 1. got: %v", errc.Index)
		}
	default:
		t.Fatal("should have caused an error and queued it in scd.ChunksError")
	}
}

func getChunkTestErrorStatus(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) (
	*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusBadGateway,
		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
	}, nil
}

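// Explanatory note (not part of the original file): with FuncGet stubbed to
// return a 502 response, the test below expects downloadChunk and
// downloadChunkHelper to wrap the failure in a *SnowflakeError carrying
// ErrFailedToGetChunk and to queue it on scd.ChunksError for the reader.
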
func TestDownloadChunkErrorStatus(t *testing.T) {
	numChunks := 2
	cm := make([]execResponseChunk, 0)
	for i := 0; i < numChunks; i++ {
		cm = append(cm, execResponseChunk{URL: fmt.Sprintf(
			"dummyURL%v", i+1), RowCount: rowsInChunk})
	}
	scd := &snowflakeChunkDownloader{
		sc: &snowflakeConn{
			rest: &snowflakeRestful{RequestTimeout: defaultRequestTimeout},
		},
		ctx:                context.Background(),
		ChunkMetas:         cm,
		TotalRowIndex:      int64(-1),
		Qrmk:               "HOHOHO",
		FuncDownload:       downloadChunk,
		FuncDownloadHelper: downloadChunkHelper,
		FuncGet:            getChunkTestErrorStatus,
	}
	scd.ChunksMutex = &sync.Mutex{}
	scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex)
	scd.Chunks = make(map[int][]chunkRowType)
	scd.ChunksError = make(chan *chunkError, 1)
	scd.FuncDownload(scd.ctx, scd, 1)
	select {
	case errc := <-scd.ChunksError:
		if errc.Index != 1 {
			t.Fatalf("the error should have been caused by chunk idx 1. got: %v", errc.Index)
		}
		serr, ok := errc.Error.(*SnowflakeError)
		if !ok {
			t.Fatalf("should have been a snowflake error. err: %v", errc.Error)
		}
		if serr.Number != ErrFailedToGetChunk {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrFailedToGetChunk, serr.Number)
		}
	default:
		t.Fatal("should have caused an error and queued it in scd.ChunksError")
	}
}

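// Explanatory note (not part of the original file): the test below runs an
// INSERT under a WithArrowBatches context, so the statement produces a
// snowflakeResult rather than row data; calling GetArrowBatches on such an
// exec result is expected to fail with ErrNotImplemented.
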
func TestWithArrowBatchesNotImplementedForResult(t *testing.T) {
	ctx := WithArrowBatches(context.Background())
	runSnowflakeConnTest(t, func(sct *SCTest) {
		sct.mustExec("create or replace table testArrowBatches (a int, b int)", nil)
		defer sct.sc.Exec("drop table if exists testArrowBatches", nil)

		result := sct.mustExecContext(ctx, "insert into testArrowBatches values (1, 2), (3, 4), (5, 6)", []driver.NamedValue{})

		_, err := result.(*snowflakeResult).GetArrowBatches()
		if err == nil {
			t.Fatal("should have raised an error")
		}
		driverErr, ok := err.(*SnowflakeError)
		if !ok {
			t.Fatalf("should be a snowflake error. err: %v", err)
		}
		if driverErr.Number != ErrNotImplemented {
			t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNotImplemented, driverErr.Number)
		}
	})
}

func TestLocationChangesAfterAlterSession(t *testing.T) {
	runDBTest(t, func(dbt *DBTest) {
		dbt.mustExec("CREATE OR REPLACE TABLE location_timestamp_ltz (val timestamp_ltz)")
		defer dbt.mustExec("DROP TABLE location_timestamp_ltz")
		dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'")
		dbt.mustExec("INSERT INTO location_timestamp_ltz VALUES('2023-08-09 10:00:00')")
		rows1 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz")
		defer rows1.Close()
		if !rows1.Next() {
			t.Fatalf("cannot read a record")
		}
		var t1 time.Time
		if err := rows1.Scan(&t1); err != nil {
			t.Fatalf("failed to scan timestamp. err: %v", err)
		}
		if t1.Location().String() != "Europe/Warsaw" {
			t.Fatalf("should return time in Warsaw timezone")
		}
		dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Pacific/Honolulu'")
		rows2 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz")
		defer rows2.Close()
		if !rows2.Next() {
			t.Fatalf("cannot read a record")
		}
		var t2 time.Time
		if err := rows2.Scan(&t2); err != nil {
			t.Fatalf("failed to scan timestamp. err: %v", err)
		}
		if t2.Location().String() != "Pacific/Honolulu" {
			t.Fatalf("should return time in Honolulu timezone")
		}
	})
}