github.com/snowflakedb/gosnowflake@v1.9.0/rows_test.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "context" 7 "database/sql" 8 "database/sql/driver" 9 "fmt" 10 "io" 11 "net/http" 12 "sync" 13 "testing" 14 "time" 15 ) 16 17 type RowsExtended struct { 18 rows *sql.Rows 19 closeChan *chan bool 20 } 21 22 func (rs *RowsExtended) Close() error { 23 *rs.closeChan <- true 24 close(*rs.closeChan) 25 return rs.rows.Close() 26 } 27 28 func (rs *RowsExtended) ColumnTypes() ([]*sql.ColumnType, error) { 29 return rs.rows.ColumnTypes() 30 } 31 32 func (rs *RowsExtended) Columns() ([]string, error) { 33 return rs.rows.Columns() 34 } 35 36 func (rs *RowsExtended) Err() error { 37 return rs.rows.Err() 38 } 39 40 func (rs *RowsExtended) Next() bool { 41 return rs.rows.Next() 42 } 43 44 func (rs *RowsExtended) NextResultSet() bool { 45 return rs.rows.NextResultSet() 46 } 47 48 func (rs *RowsExtended) Scan(dest ...interface{}) error { 49 return rs.rows.Scan(dest...) 50 } 51 52 // test variables 53 var ( 54 rowsInChunk = 123 55 ) 56 57 // Special cases where rows are already closed 58 func TestRowsClose(t *testing.T) { 59 runDBTest(t, func(dbt *DBTest) { 60 rows, err := dbt.query("SELECT 1") 61 if err != nil { 62 dbt.Fatal(err) 63 } 64 if err = rows.Close(); err != nil { 65 dbt.Fatal(err) 66 } 67 68 if rows.Next() { 69 dbt.Fatal("unexpected row after rows.Close()") 70 } 71 if err = rows.Err(); err != nil { 72 dbt.Fatal(err) 73 } 74 }) 75 } 76 77 func TestResultNoRows(t *testing.T) { 78 // DDL 79 runDBTest(t, func(dbt *DBTest) { 80 row, err := dbt.exec("CREATE OR REPLACE TABLE test(c1 int)") 81 if err != nil { 82 t.Fatalf("failed to execute DDL. err: %v", err) 83 } 84 if _, err = row.RowsAffected(); err == nil { 85 t.Fatal("should have failed to get RowsAffected") 86 } 87 if _, err = row.LastInsertId(); err == nil { 88 t.Fatal("should have failed to get LastInsertID") 89 } 90 }) 91 } 92 93 func TestRowsWithoutChunkDownloader(t *testing.T) { 94 sts1 := "1" 95 sts2 := "Test1" 96 var i int 97 cc := make([][]*string, 0) 98 for i = 0; i < 10; i++ { 99 cc = append(cc, []*string{&sts1, &sts2}) 100 } 101 rt := []execResponseRowType{ 102 {Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true}, 103 {Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false}, 104 } 105 cm := []execResponseChunk{} 106 rows := new(snowflakeRows) 107 rows.sc = nil 108 rows.ChunkDownloader = &snowflakeChunkDownloader{ 109 sc: nil, 110 ctx: context.Background(), 111 Total: int64(len(cc)), 112 ChunkMetas: cm, 113 TotalRowIndex: int64(-1), 114 Qrmk: "", 115 FuncDownload: nil, 116 FuncDownloadHelper: nil, 117 RowSet: rowSetType{RowType: rt, JSON: cc}, 118 QueryResultFormat: "json", 119 } 120 rows.ChunkDownloader.start() 121 dest := make([]driver.Value, 2) 122 for i = 0; i < len(cc); i++ { 123 if err := rows.Next(dest); err != nil { 124 t.Fatalf("failed to get value. err: %v", err) 125 } 126 if dest[0] != sts1 { 127 t.Fatalf("failed to get value. expected: %v, got: %v", sts1, dest[0]) 128 } 129 if dest[1] != sts2 { 130 t.Fatalf("failed to get value. expected: %v, got: %v", sts2, dest[1]) 131 } 132 } 133 if err := rows.Next(dest); err != io.EOF { 134 t.Fatalf("failed to finish getting data. err: %v", err) 135 } 136 logger.Infof("dest: %v", dest) 137 138 } 139 140 func downloadChunkTest(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { 141 d := make([][]*string, 0) 142 for i := 0; i < rowsInChunk; i++ { 143 v1 := fmt.Sprintf("%v", idx*1000+i) 144 v2 := fmt.Sprintf("testchunk%v", idx*1000+i) 145 d = append(d, []*string{&v1, &v2}) 146 } 147 scd.ChunksMutex.Lock() 148 scd.Chunks[idx] = make([]chunkRowType, len(d)) 149 populateJSONRowSet(scd.Chunks[idx], d) 150 scd.DoneDownloadCond.Broadcast() 151 scd.ChunksMutex.Unlock() 152 } 153 154 func TestRowsWithChunkDownloader(t *testing.T) { 155 numChunks := 12 156 // changed the workers 157 backupMaxChunkDownloadWorkers := MaxChunkDownloadWorkers 158 MaxChunkDownloadWorkers = 2 159 logger.Info("START TESTS") 160 var i int 161 cc := make([][]*string, 0) 162 for i = 0; i < 100; i++ { 163 v1 := fmt.Sprintf("%v", i) 164 v2 := fmt.Sprintf("Test%v", i) 165 cc = append(cc, []*string{&v1, &v2}) 166 } 167 rt := []execResponseRowType{ 168 {Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true}, 169 {Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false}, 170 } 171 cm := make([]execResponseChunk, 0) 172 for i = 0; i < numChunks; i++ { 173 cm = append(cm, execResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk}) 174 } 175 rows := new(snowflakeRows) 176 rows.sc = nil 177 rows.ChunkDownloader = &snowflakeChunkDownloader{ 178 sc: nil, 179 ctx: context.Background(), 180 Total: int64(len(cc) + numChunks*rowsInChunk), 181 ChunkMetas: cm, 182 TotalRowIndex: int64(-1), 183 Qrmk: "HAHAHA", 184 FuncDownload: downloadChunkTest, 185 RowSet: rowSetType{RowType: rt, JSON: cc}, 186 } 187 rows.ChunkDownloader.start() 188 cnt := 0 189 dest := make([]driver.Value, 2) 190 var err error 191 for err != io.EOF { 192 err := rows.Next(dest) 193 if err == io.EOF { 194 break 195 } 196 if err != nil { 197 t.Fatalf("failed to get value. err: %v", err) 198 } 199 cnt++ 200 } 201 if cnt != len(cc)+numChunks*rowsInChunk { 202 t.Fatalf("failed to get all results. expected:%v, got:%v", len(cc)+numChunks*rowsInChunk, cnt) 203 } 204 logger.Infof("dest: %v", dest) 205 MaxChunkDownloadWorkers = backupMaxChunkDownloadWorkers 206 logger.Info("END TESTS") 207 } 208 209 func downloadChunkTestError(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { 210 // fail to download 6th and 10th chunk, and retry up to N times and success 211 // NOTE: zero based index 212 scd.ChunksMutex.Lock() 213 defer scd.ChunksMutex.Unlock() 214 if (idx == 6 || idx == 10) && scd.ChunksErrorCounter < maxChunkDownloaderErrorCounter { 215 scd.ChunksError <- &chunkError{ 216 Index: idx, 217 Error: fmt.Errorf( 218 "dummy error. idx: %v, errCnt: %v", idx+1, scd.ChunksErrorCounter)} 219 scd.DoneDownloadCond.Broadcast() 220 return 221 } 222 d := make([][]*string, 0) 223 for i := 0; i < rowsInChunk; i++ { 224 v1 := fmt.Sprintf("%v", idx*1000+i) 225 v2 := fmt.Sprintf("testchunk%v", idx*1000+i) 226 d = append(d, []*string{&v1, &v2}) 227 } 228 scd.Chunks[idx] = make([]chunkRowType, len(d)) 229 populateJSONRowSet(scd.Chunks[idx], d) 230 scd.DoneDownloadCond.Broadcast() 231 } 232 233 func TestRowsWithChunkDownloaderError(t *testing.T) { 234 numChunks := 12 235 // changed the workers 236 backupMaxChunkDownloadWorkers := MaxChunkDownloadWorkers 237 MaxChunkDownloadWorkers = 3 238 logger.Info("START TESTS") 239 var i int 240 cc := make([][]*string, 0) 241 for i = 0; i < 100; i++ { 242 v1 := fmt.Sprintf("%v", i) 243 v2 := fmt.Sprintf("Test%v", i) 244 cc = append(cc, []*string{&v1, &v2}) 245 } 246 rt := []execResponseRowType{ 247 {Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true}, 248 {Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false}, 249 } 250 cm := make([]execResponseChunk, 0) 251 for i = 0; i < numChunks; i++ { 252 cm = append(cm, execResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk}) 253 } 254 rows := new(snowflakeRows) 255 rows.sc = nil 256 rows.ChunkDownloader = &snowflakeChunkDownloader{ 257 sc: nil, 258 ctx: context.Background(), 259 Total: int64(len(cc) + numChunks*rowsInChunk), 260 ChunkMetas: cm, 261 TotalRowIndex: int64(-1), 262 Qrmk: "HOHOHO", 263 FuncDownload: downloadChunkTestError, 264 RowSet: rowSetType{RowType: rt, JSON: cc}, 265 } 266 rows.ChunkDownloader.start() 267 cnt := 0 268 dest := make([]driver.Value, 2) 269 var err error 270 for err != io.EOF { 271 err := rows.Next(dest) 272 if err == io.EOF { 273 break 274 } 275 if err != nil { 276 t.Fatalf("failed to get value. err: %v", err) 277 } 278 // fmt.Printf("data: %v\n", dest) 279 cnt++ 280 } 281 if cnt != len(cc)+numChunks*rowsInChunk { 282 t.Fatalf("failed to get all results. expected:%v, got:%v", len(cc)+numChunks*rowsInChunk, cnt) 283 } 284 logger.Infof("dest: %v", dest) 285 MaxChunkDownloadWorkers = backupMaxChunkDownloadWorkers 286 logger.Info("END TESTS") 287 } 288 289 func downloadChunkTestErrorFail(ctx context.Context, scd *snowflakeChunkDownloader, idx int) { 290 // fail to download 6th and 10th chunk, and retry up to N times and fail 291 // NOTE: zero based index 292 scd.ChunksMutex.Lock() 293 defer scd.ChunksMutex.Unlock() 294 if idx == 6 && scd.ChunksErrorCounter <= maxChunkDownloaderErrorCounter { 295 scd.ChunksError <- &chunkError{ 296 Index: idx, 297 Error: fmt.Errorf( 298 "dummy error. idx: %v, errCnt: %v", idx+1, scd.ChunksErrorCounter)} 299 scd.DoneDownloadCond.Broadcast() 300 return 301 } 302 d := make([][]*string, 0) 303 for i := 0; i < rowsInChunk; i++ { 304 v1 := fmt.Sprintf("%v", idx*1000+i) 305 v2 := fmt.Sprintf("testchunk%v", idx*1000+i) 306 d = append(d, []*string{&v1, &v2}) 307 } 308 scd.Chunks[idx] = make([]chunkRowType, len(d)) 309 populateJSONRowSet(scd.Chunks[idx], d) 310 scd.DoneDownloadCond.Broadcast() 311 } 312 313 func TestRowsWithChunkDownloaderErrorFail(t *testing.T) { 314 numChunks := 12 315 // changed the workers 316 logger.Info("START TESTS") 317 var i int 318 cc := make([][]*string, 0) 319 for i = 0; i < 100; i++ { 320 v1 := fmt.Sprintf("%v", i) 321 v2 := fmt.Sprintf("Test%v", i) 322 cc = append(cc, []*string{&v1, &v2}) 323 } 324 rt := []execResponseRowType{ 325 {Name: "c1", ByteLength: 10, Length: 10, Type: "FIXED", Scale: 0, Nullable: true}, 326 {Name: "c2", ByteLength: 100000, Length: 100000, Type: "TEXT", Scale: 0, Nullable: false}, 327 } 328 cm := make([]execResponseChunk, 0) 329 for i = 0; i < numChunks; i++ { 330 cm = append(cm, execResponseChunk{URL: fmt.Sprintf("dummyURL%v", i+1), RowCount: rowsInChunk}) 331 } 332 rows := new(snowflakeRows) 333 rows.sc = nil 334 rows.ChunkDownloader = &snowflakeChunkDownloader{ 335 sc: nil, 336 ctx: context.Background(), 337 Total: int64(len(cc) + numChunks*rowsInChunk), 338 ChunkMetas: cm, 339 TotalRowIndex: int64(-1), 340 Qrmk: "HOHOHO", 341 FuncDownload: downloadChunkTestErrorFail, 342 RowSet: rowSetType{RowType: rt, JSON: cc}, 343 } 344 rows.ChunkDownloader.start() 345 cnt := 0 346 dest := make([]driver.Value, 2) 347 var err error 348 for err != io.EOF { 349 err := rows.Next(dest) 350 if err == io.EOF { 351 break 352 } 353 if err != nil { 354 logger.Infof( 355 "failure was expected by the number of rows is wrong. expected: %v, got: %v", 715, cnt) 356 break 357 } 358 cnt++ 359 } 360 } 361 362 func getChunkTestInvalidResponseBody(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) ( 363 *http.Response, error) { 364 return &http.Response{ 365 StatusCode: http.StatusOK, 366 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 367 }, nil 368 } 369 370 func TestDownloadChunkInvalidResponseBody(t *testing.T) { 371 numChunks := 2 372 cm := make([]execResponseChunk, 0) 373 for i := 0; i < numChunks; i++ { 374 cm = append(cm, execResponseChunk{URL: fmt.Sprintf( 375 "dummyURL%v", i+1), RowCount: rowsInChunk}) 376 } 377 scd := &snowflakeChunkDownloader{ 378 sc: &snowflakeConn{ 379 rest: &snowflakeRestful{RequestTimeout: defaultRequestTimeout}, 380 }, 381 ctx: context.Background(), 382 ChunkMetas: cm, 383 TotalRowIndex: int64(-1), 384 Qrmk: "HOHOHO", 385 FuncDownload: downloadChunk, 386 FuncDownloadHelper: downloadChunkHelper, 387 FuncGet: getChunkTestInvalidResponseBody, 388 } 389 scd.ChunksMutex = &sync.Mutex{} 390 scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex) 391 scd.Chunks = make(map[int][]chunkRowType) 392 scd.ChunksError = make(chan *chunkError, 1) 393 scd.FuncDownload(scd.ctx, scd, 1) 394 select { 395 case errc := <-scd.ChunksError: 396 if errc.Index != 1 { 397 t.Fatalf("the error should have caused with chunk idx: %v", errc.Index) 398 } 399 default: 400 t.Fatal("should have caused an error and queued in scd.ChunksError") 401 } 402 } 403 404 func getChunkTestErrorStatus(_ context.Context, _ *snowflakeConn, _ string, _ map[string]string, _ time.Duration) ( 405 *http.Response, error) { 406 return &http.Response{ 407 StatusCode: http.StatusBadGateway, 408 Body: &fakeResponseBody{body: []byte{0x12, 0x34}}, 409 }, nil 410 } 411 412 func TestDownloadChunkErrorStatus(t *testing.T) { 413 numChunks := 2 414 cm := make([]execResponseChunk, 0) 415 for i := 0; i < numChunks; i++ { 416 cm = append(cm, execResponseChunk{URL: fmt.Sprintf( 417 "dummyURL%v", i+1), RowCount: rowsInChunk}) 418 } 419 scd := &snowflakeChunkDownloader{ 420 sc: &snowflakeConn{ 421 rest: &snowflakeRestful{RequestTimeout: defaultRequestTimeout}, 422 }, 423 ctx: context.Background(), 424 ChunkMetas: cm, 425 TotalRowIndex: int64(-1), 426 Qrmk: "HOHOHO", 427 FuncDownload: downloadChunk, 428 FuncDownloadHelper: downloadChunkHelper, 429 FuncGet: getChunkTestErrorStatus, 430 } 431 scd.ChunksMutex = &sync.Mutex{} 432 scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex) 433 scd.Chunks = make(map[int][]chunkRowType) 434 scd.ChunksError = make(chan *chunkError, 1) 435 scd.FuncDownload(scd.ctx, scd, 1) 436 select { 437 case errc := <-scd.ChunksError: 438 if errc.Index != 1 { 439 t.Fatalf("the error should have caused with chunk idx: %v", errc.Index) 440 } 441 serr, ok := errc.Error.(*SnowflakeError) 442 if !ok { 443 t.Fatalf("should have been snowflake error. err: %v", errc.Error) 444 } 445 if serr.Number != ErrFailedToGetChunk { 446 t.Fatalf("message error code is not correct. msg: %v", serr.Number) 447 } 448 default: 449 t.Fatal("should have caused an error and queued in scd.ChunksError") 450 } 451 } 452 453 func TestWithArrowBatchesNotImplementedForResult(t *testing.T) { 454 ctx := WithArrowBatches(context.Background()) 455 runSnowflakeConnTest(t, func(sct *SCTest) { 456 457 sct.mustExec("create or replace table testArrowBatches (a int, b int)", nil) 458 defer sct.sc.Exec("drop table if exists testArrowBatches", nil) 459 460 result := sct.mustExecContext(ctx, "insert into testArrowBatches values (1, 2), (3, 4), (5, 6)", []driver.NamedValue{}) 461 462 _, err := result.(*snowflakeResult).GetArrowBatches() 463 if err == nil { 464 t.Fatal("should have raised an error") 465 } 466 driverErr, ok := err.(*SnowflakeError) 467 if !ok { 468 t.Fatalf("should be snowflake error. err: %v", err) 469 } 470 if driverErr.Number != ErrNotImplemented { 471 t.Fatalf("unexpected error code. expected: %v, got: %v", ErrNotImplemented, driverErr.Number) 472 } 473 }) 474 } 475 476 func TestLocationChangesAfterAlterSession(t *testing.T) { 477 runDBTest(t, func(dbt *DBTest) { 478 dbt.mustExec("CREATE OR REPLACE TABLE location_timestamp_ltz (val timestamp_ltz)") 479 defer dbt.mustExec("DROP TABLE location_timestamp_ltz") 480 dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Europe/Warsaw'") 481 dbt.mustExec("INSERT INTO location_timestamp_ltz VALUES('2023-08-09 10:00:00')") 482 rows1 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz") 483 defer rows1.Close() 484 if !rows1.Next() { 485 t.Fatalf("cannot read a record") 486 } 487 var t1 time.Time 488 rows1.Scan(&t1) 489 if t1.Location().String() != "Europe/Warsaw" { 490 t.Fatalf("should return time in Warsaw timezone") 491 } 492 dbt.mustExec("ALTER SESSION SET TIMEZONE = 'Pacific/Honolulu'") 493 rows2 := dbt.mustQuery("SELECT * FROM location_timestamp_ltz") 494 defer rows2.Close() 495 if !rows2.Next() { 496 t.Fatalf("cannot read a record") 497 } 498 var t2 time.Time 499 rows2.Scan(&t2) 500 if t2.Location().String() != "Pacific/Honolulu" { 501 t.Fatalf("should return time in Honolulu timezone") 502 } 503 }) 504 }