github.com/snowflakedb/gosnowflake@v1.9.0/chunk_test.go (about) 1 // Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bytes" 7 "context" 8 "database/sql/driver" 9 "encoding/json" 10 "fmt" 11 "io" 12 "math/rand" 13 "strings" 14 "sync" 15 "sync/atomic" 16 "testing" 17 18 "github.com/apache/arrow/go/v15/arrow/ipc" 19 "github.com/apache/arrow/go/v15/arrow/memory" 20 ) 21 22 func TestBadChunkData(t *testing.T) { 23 testDecodeErr(t, "") 24 testDecodeErr(t, "null") 25 testDecodeErr(t, "42") 26 testDecodeErr(t, "\"null\"") 27 testDecodeErr(t, "{}") 28 29 testDecodeErr(t, "[[]") 30 testDecodeErr(t, "[null]") 31 testDecodeErr(t, `[[hello world]]`) 32 33 testDecodeErr(t, `[[""hello world""]]`) 34 testDecodeErr(t, `[["\"hello world""]]`) 35 testDecodeErr(t, `[[""hello world\""]]`) 36 testDecodeErr(t, `[["hello world`) 37 testDecodeErr(t, `[["hello world"`) 38 testDecodeErr(t, `[["hello world"]`) 39 40 testDecodeErr(t, `[["\uQQQQ"]]`) 41 42 for b := byte(0); b < ' '; b++ { 43 testDecodeErr(t, string([]byte{ 44 '[', '[', '"', b, '"', ']', ']', 45 })) 46 } 47 } 48 49 func TestValidChunkData(t *testing.T) { 50 testDecodeOk(t, "[]") 51 testDecodeOk(t, "[ ]") 52 testDecodeOk(t, "[[]]") 53 testDecodeOk(t, "[ [ ] ]") 54 testDecodeOk(t, "[[],[],[],[]]") 55 testDecodeOk(t, "[[] , [] , [], [] ]") 56 57 testDecodeOk(t, "[[null]]") 58 testDecodeOk(t, "[[\n\t\r null]]") 59 testDecodeOk(t, "[[null,null]]") 60 testDecodeOk(t, "[[ null , null ]]") 61 testDecodeOk(t, "[[null],[null],[null]]") 62 testDecodeOk(t, "[[null],[ null ] , [null]]") 63 64 testDecodeOk(t, `[[""]]`) 65 testDecodeOk(t, `[["false"]]`) 66 testDecodeOk(t, `[["true"]]`) 67 testDecodeOk(t, `[["42"]]`) 68 69 testDecodeOk(t, `[[""]]`) 70 testDecodeOk(t, `[["hello"]]`) 71 testDecodeOk(t, `[["hello world"]]`) 72 73 testDecodeOk(t, `[["/ ' \\ \b \t \n \f \r \""]]`) 74 testDecodeOk(t, `[["❄"]]`) 75 testDecodeOk(t, `[["\u2744"]]`) 76 testDecodeOk(t, `[["\uFfFc"]]`) // consume replacement chars 77 testDecodeOk(t, `[["\ufffd"]]`) // consume replacement chars 78 testDecodeOk(t, `[["\u0000"]]`) // yes, this is valid 79 testDecodeOk(t, `[["\uD834\uDD1E"]]`) // surrogate pair 80 testDecodeOk(t, `[["\uD834\u0000"]]`) // corrupt surrogate pair 81 82 testDecodeOk(t, `[["$"]]`) // "$" 83 testDecodeOk(t, `[["\u0024"]]`) // "$" 84 85 testDecodeOk(t, `[["\uC2A2"]]`) // "¢" 86 testDecodeOk(t, `[["¢"]]`) // "¢" 87 88 testDecodeOk(t, `[["\u00E2\u82AC"]]`) // "€" 89 testDecodeOk(t, `[["€"]]`) // "€" 90 91 testDecodeOk(t, `[["\uF090\u8D88"]]`) // "𐍈" 92 testDecodeOk(t, `[["𐍈"]]`) // "𐍈" 93 } 94 95 func TestSmallBufferChunkData(t *testing.T) { 96 r := strings.NewReader(`[ 97 [null,"hello world"], 98 ["foo bar", null], 99 [null, null] , 100 ["foo bar", "hello world" ] 101 ]`) 102 103 lcd := largeChunkDecoder{ 104 r, 0, 0, 105 0, 0, 106 make([]byte, 1), 107 bytes.NewBuffer(make([]byte, defaultStringBufferSize)), 108 nil, 109 } 110 111 if _, err := lcd.decode(); err != nil { 112 t.Fatalf("failed with small buffer: %s", err) 113 } 114 } 115 116 func TestEnsureBytes(t *testing.T) { 117 // the content here doesn't matter 118 r := strings.NewReader("0123456789") 119 120 lcd := largeChunkDecoder{ 121 r, 0, 0, 122 3, 8189, 123 make([]byte, 8192), 124 bytes.NewBuffer(make([]byte, defaultStringBufferSize)), 125 nil, 126 } 127 128 lcd.ensureBytes(4) 129 130 // we expect the new remainder to be 3 + 10 (length of r) 131 if lcd.rem != 13 { 132 t.Fatalf("buffer was not refilled correctly") 133 } 134 } 135 136 func testDecodeOk(t *testing.T, s string) { 137 var rows [][]*string 138 if err := json.Unmarshal([]byte(s), &rows); err != nil { 139 t.Fatalf("test case is not valid json / [][]*string: %s", s) 140 } 141 142 // NOTE we parse and stringify the expected result to 143 // remove superficial differences, like whitespace 144 expect, err := json.Marshal(rows) 145 if err != nil { 146 t.Fatalf("unreachable: %s", err) 147 } 148 149 rows, err = decodeLargeChunk(strings.NewReader(s), 0, 0) 150 if err != nil { 151 t.Fatalf("expected decode to succeed: %s", err) 152 } 153 154 actual, err := json.Marshal(rows) 155 if err != nil { 156 t.Fatalf("json marshal failed: %s", err) 157 } 158 if string(actual) != string(expect) { 159 t.Fatalf(` 160 result did not match expected result 161 expect=%s 162 bytes=(%v) 163 164 acutal=%s 165 bytes=(%v)`, 166 string(expect), expect, 167 string(actual), actual, 168 ) 169 } 170 } 171 172 func testDecodeErr(t *testing.T, s string) { 173 if _, err := decodeLargeChunk(strings.NewReader(s), 0, 0); err == nil { 174 t.Fatalf("expected decode to fail for input: %s", s) 175 } 176 } 177 178 type mockStreamChunkFetcher struct { 179 chunks map[string][][]*string 180 } 181 182 func (f *mockStreamChunkFetcher) fetch(url string, stream chan<- []*string) error { 183 for _, row := range f.chunks[url] { 184 stream <- row 185 } 186 return nil 187 } 188 189 func TestStreamChunkDownloaderFirstRows(t *testing.T) { 190 fetcher := &mockStreamChunkFetcher{} 191 firstRows := generateStreamChunkRows(10, 4) 192 downloader := newStreamChunkDownloader( 193 context.Background(), 194 fetcher, 195 int64(len(firstRows)), 196 []execResponseRowType{}, 197 firstRows, 198 []execResponseChunk{}) 199 if err := downloader.start(); err != nil { 200 t.Fatalf("chunk download start failed. err: %v", err) 201 } 202 for i := 0; i < len(firstRows); i++ { 203 if !downloader.hasNextResultSet() { 204 t.Error("failed to retrieve next result set") 205 } 206 if err := downloader.nextResultSet(); err != nil { 207 t.Fatalf("failed to retrieve data. err: %v", err) 208 } 209 row, err := downloader.next() 210 if err != nil { 211 t.Fatalf("failed to retrieve data. err: %v", err) 212 } 213 assertEqualRows(firstRows[i], row) 214 } 215 row, err := downloader.next() 216 if !assertEmptyChunkRow(row) { 217 t.Fatal("row should be empty") 218 } 219 if err != io.EOF { 220 t.Fatalf("failed to finish getting data. err: %v", err) 221 } 222 if downloader.hasNextResultSet() { 223 t.Error("downloader has next result set. expected none.") 224 } 225 if downloader.nextResultSet() != io.EOF { 226 t.Fatalf("failed to finish getting data. err: %v", err) 227 } 228 } 229 230 func TestStreamChunkDownloaderChunks(t *testing.T) { 231 chunks, responseChunks := generateStreamChunkDownloaderChunks([]string{"foo", "bar"}, 4, 4) 232 fetcher := &mockStreamChunkFetcher{chunks} 233 firstRows := generateStreamChunkRows(2, 4) 234 downloader := newStreamChunkDownloader( 235 context.Background(), 236 fetcher, 237 int64(len(firstRows)), 238 []execResponseRowType{}, 239 firstRows, 240 responseChunks) 241 if err := downloader.start(); err != nil { 242 t.Fatalf("chunk download start failed. err: %v", err) 243 } 244 for i := 0; i < len(firstRows); i++ { 245 if !downloader.hasNextResultSet() { 246 t.Error("failed to retrieve next result set") 247 } 248 if err := downloader.nextResultSet(); err != nil { 249 t.Fatalf("failed to retrieve data. err: %v", err) 250 } 251 row, err := downloader.next() 252 if err != nil { 253 t.Fatalf("failed to retrieve data. err: %v", err) 254 } 255 assertEqualRows(firstRows[i], row) 256 } 257 for _, chunk := range responseChunks { 258 for _, chunkRow := range chunks[chunk.URL] { 259 if !downloader.hasNextResultSet() { 260 t.Error("failed to retrieve next result set") 261 } 262 row, err := downloader.next() 263 if err != nil { 264 t.Fatalf("failed to retrieve data. err: %v", err) 265 } 266 assertEqualRows(chunkRow, row) 267 } 268 } 269 row, err := downloader.next() 270 if !assertEmptyChunkRow(row) { 271 t.Fatal("row should be empty") 272 } 273 if err != io.EOF { 274 t.Fatalf("failed to finish getting data. err: %v", err) 275 } 276 if downloader.hasNextResultSet() { 277 t.Error("downloader has next result set. expected none.") 278 } 279 if downloader.nextResultSet() != io.EOF { 280 t.Fatalf("failed to finish getting data. err: %v", err) 281 } 282 } 283 284 func TestCopyChunkStream(t *testing.T) { 285 foo := "foo" 286 bar := "bar" 287 288 r := strings.NewReader(`["foo","bar",null],["bar",null,"foo"],[]`) 289 c := make(chan []*string, 3) 290 if err := copyChunkStream(r, c); err != nil { 291 t.Fatalf("error while copying chunk stream. err: %v", err) 292 } 293 assertEqualRows([]*string{&foo, &bar, nil}, <-c) 294 assertEqualRows([]*string{&bar, nil, &foo}, <-c) 295 assertEqualRows([]*string{}, <-c) 296 } 297 298 func TestCopyChunkStreamInvalid(t *testing.T) { 299 var r io.Reader 300 var c chan []*string 301 var err error 302 303 r = strings.NewReader("oops") 304 c = make(chan []*string, 1) 305 if err = copyChunkStream(r, c); err == nil { 306 t.Fatalf("should fail to retrieve data. err: %v", err) 307 } 308 309 r = strings.NewReader(`[["foo"], ["bar"]]`) 310 c = make(chan []*string, 1) 311 if err = copyChunkStream(r, c); err == nil { 312 t.Fatalf("should fail to retrieve data. err: %v", err) 313 } 314 315 r = strings.NewReader(`{"foo": "bar"}`) 316 c = make(chan []*string, 1) 317 if err = copyChunkStream(r, c); err == nil { 318 t.Fatalf("should fail to retrieve data. err: %v", err) 319 } 320 } 321 322 func generateStreamChunkDownloaderChunks(urls []string, numRows, numCols int) (map[string][][]*string, []execResponseChunk) { 323 chunks := map[string][][]*string{} 324 var responseChunks []execResponseChunk 325 for _, url := range urls { 326 rows := generateStreamChunkRows(numRows, numCols) 327 chunks[url] = rows 328 responseChunks = append(responseChunks, execResponseChunk{url, len(rows), -1, -1}) 329 } 330 return chunks, responseChunks 331 } 332 333 func generateStreamChunkRows(numRows, numCols int) [][]*string { 334 var rows [][]*string 335 for i := 0; i < numRows; i++ { 336 var cols []*string 337 for j := 0; j < numCols; j++ { 338 col := fmt.Sprintf("%d", rand.Intn(1000)) 339 cols = append(cols, &col) 340 } 341 rows = append(rows, cols) 342 } 343 return rows 344 } 345 346 func assertEqualRows(expected []*string, actual interface{}) bool { 347 switch v := actual.(type) { 348 case chunkRowType: 349 for i := range expected { 350 if expected[i] != v.RowSet[i] { 351 return false 352 } 353 } 354 return len(expected) == len(v.RowSet) 355 case []*string: 356 for i := range expected { 357 if expected[i] != v[i] { 358 return false 359 } 360 } 361 return len(expected) == len(v) 362 } 363 return false 364 } 365 366 func assertEmptyChunkRow(row chunkRowType) bool { 367 return assertEqualRows(make([]*string, len(row.RowSet)), row) 368 } 369 370 func TestWithStreamDownloader(t *testing.T) { 371 ctx := WithStreamDownloader(context.Background()) 372 numrows := 100000 373 cnt := 0 374 var idx int 375 var v string 376 377 runDBTest(t, func(dbt *DBTest) { 378 dbt.mustExec(forceJSON) 379 rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows)) 380 defer rows.Close() 381 382 // Next() will block and wait until results are available 383 for rows.Next() { 384 if err := rows.Scan(&idx, &v); err != nil { 385 t.Fatal(err) 386 } 387 cnt++ 388 } 389 logger.Infof("NextResultSet: %v", rows.NextResultSet()) 390 391 if cnt != numrows { 392 t.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt) 393 } 394 }) 395 } 396 397 func TestWithArrowBatches(t *testing.T) { 398 runSnowflakeConnTest(t, func(sct *SCTest) { 399 ctx := WithArrowBatches(sct.sc.ctx) 400 numrows := 3000 // approximately 6 ArrowBatch objects 401 402 pool := memory.NewCheckedAllocator(memory.DefaultAllocator) 403 defer pool.AssertSize(t, 0) 404 ctx = WithArrowAllocator(ctx, pool) 405 406 query := fmt.Sprintf(selectRandomGenerator, numrows) 407 rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) 408 defer rows.Close() 409 410 // getting result batches 411 batches, err := rows.(*snowflakeRows).GetArrowBatches() 412 if err != nil { 413 t.Error(err) 414 } 415 numBatches := len(batches) 416 maxWorkers := 10 // enough for 3000 rows 417 type count struct { 418 m sync.Mutex 419 recVal int 420 metaVal int 421 } 422 cnt := count{recVal: 0} 423 var wg sync.WaitGroup 424 chunks := make(chan int, numBatches) 425 426 // kicking off download workers - each of which will call fetch on a different result batch 427 for w := 1; w <= maxWorkers; w++ { 428 wg.Add(1) 429 go func(wg *sync.WaitGroup, chunks <-chan int) { 430 defer wg.Done() 431 432 for i := range chunks { 433 rec, err := batches[i].Fetch() 434 if err != nil { 435 t.Error(err) 436 } 437 for _, r := range *rec { 438 cnt.m.Lock() 439 cnt.recVal += int(r.NumRows()) 440 cnt.m.Unlock() 441 r.Release() 442 } 443 cnt.m.Lock() 444 cnt.metaVal += batches[i].rowCount 445 cnt.m.Unlock() 446 } 447 }(&wg, chunks) 448 } 449 for j := 0; j < numBatches; j++ { 450 chunks <- j 451 } 452 close(chunks) 453 454 // wait for workers to finish fetching and check row counts 455 wg.Wait() 456 if cnt.recVal != numrows { 457 t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) 458 } 459 if cnt.metaVal != numrows { 460 t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) 461 } 462 }) 463 } 464 465 func TestWithArrowBatchesAsync(t *testing.T) { 466 runSnowflakeConnTest(t, func(sct *SCTest) { 467 ctx := WithAsyncMode(sct.sc.ctx) 468 ctx = WithArrowBatches(ctx) 469 numrows := 50000 // approximately 10 ArrowBatch objects 470 471 pool := memory.NewCheckedAllocator(memory.DefaultAllocator) 472 defer pool.AssertSize(t, 0) 473 ctx = WithArrowAllocator(ctx, pool) 474 475 query := fmt.Sprintf(selectRandomGenerator, numrows) 476 rows := sct.mustQueryContext(ctx, query, []driver.NamedValue{}) 477 defer rows.Close() 478 479 // getting result batches 480 // this will fail if GetArrowBatches() is not a blocking call 481 batches, err := rows.(*snowflakeRows).GetArrowBatches() 482 if err != nil { 483 t.Error(err) 484 } 485 numBatches := len(batches) 486 maxWorkers := 10 487 type count struct { 488 m sync.Mutex 489 recVal int 490 metaVal int 491 } 492 cnt := count{recVal: 0} 493 var wg sync.WaitGroup 494 chunks := make(chan int, numBatches) 495 496 // kicking off download workers - each of which will call fetch on a different result batch 497 for w := 1; w <= maxWorkers; w++ { 498 wg.Add(1) 499 go func(wg *sync.WaitGroup, chunks <-chan int) { 500 defer wg.Done() 501 502 for i := range chunks { 503 rec, err := batches[i].Fetch() 504 if err != nil { 505 t.Error(err) 506 } 507 for _, r := range *rec { 508 cnt.m.Lock() 509 cnt.recVal += int(r.NumRows()) 510 cnt.m.Unlock() 511 r.Release() 512 } 513 cnt.m.Lock() 514 cnt.metaVal += batches[i].rowCount 515 cnt.m.Unlock() 516 } 517 }(&wg, chunks) 518 } 519 for j := 0; j < numBatches; j++ { 520 chunks <- j 521 } 522 close(chunks) 523 524 // wait for workers to finish fetching and check row counts 525 wg.Wait() 526 if cnt.recVal != numrows { 527 t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, cnt.recVal) 528 } 529 if cnt.metaVal != numrows { 530 t.Errorf("number of rows from arrow batch metadata didn't match. expected: %v, got: %v", numrows, cnt.metaVal) 531 } 532 }) 533 } 534 535 func TestQueryArrowStream(t *testing.T) { 536 runSnowflakeConnTest(t, func(sct *SCTest) { 537 numrows := 50000 // approximately 10 ArrowBatch objects 538 539 query := fmt.Sprintf(selectRandomGenerator, numrows) 540 loader, err := sct.sc.QueryArrowStream(sct.sc.ctx, query) 541 if err != nil { 542 t.Error(err) 543 } 544 545 if loader.TotalRows() != int64(numrows) { 546 t.Errorf("total numrows did not match expected, wanted %v, got %v", numrows, loader.TotalRows()) 547 } 548 549 batches, err := loader.GetBatches() 550 if err != nil { 551 t.Error(err) 552 } 553 554 numBatches := len(batches) 555 maxWorkers := 8 556 chunks := make(chan int, numBatches) 557 total := int64(0) 558 meta := int64(0) 559 560 var wg sync.WaitGroup 561 wg.Add(maxWorkers) 562 563 mem := memory.NewCheckedAllocator(memory.DefaultAllocator) 564 defer mem.AssertSize(t, 0) 565 566 for w := 0; w < maxWorkers; w++ { 567 go func() { 568 defer wg.Done() 569 570 for i := range chunks { 571 r, err := batches[i].GetStream(sct.sc.ctx) 572 if err != nil { 573 t.Error(err) 574 continue 575 } 576 rdr, err := ipc.NewReader(r, ipc.WithAllocator(mem)) 577 if err != nil { 578 t.Errorf("Error creating IPC reader for stream %d: %s", i, err) 579 r.Close() 580 continue 581 } 582 583 for rdr.Next() { 584 rec := rdr.Record() 585 atomic.AddInt64(&total, rec.NumRows()) 586 } 587 588 if rdr.Err() != nil { 589 t.Error(rdr.Err()) 590 } 591 rdr.Release() 592 if err := r.Close(); err != nil { 593 t.Error(err) 594 } 595 atomic.AddInt64(&meta, batches[i].NumRows()) 596 } 597 }() 598 } 599 600 for j := 0; j < numBatches; j++ { 601 chunks <- j 602 } 603 close(chunks) 604 wg.Wait() 605 606 if total != int64(numrows) { 607 t.Errorf("number of rows from records didn't match. expected: %v, got: %v", numrows, total) 608 } 609 if meta != int64(numrows) { 610 t.Errorf("number of rows from batch metadata didn't match. expected: %v, got: %v", numrows, total) 611 } 612 }) 613 } 614 615 func TestQueryArrowStreamDescribeOnly(t *testing.T) { 616 runSnowflakeConnTest(t, func(sct *SCTest) { 617 numrows := 50000 // approximately 10 ArrowBatch objects 618 619 query := fmt.Sprintf(selectRandomGenerator, numrows) 620 loader, err := sct.sc.QueryArrowStream(WithDescribeOnly(sct.sc.ctx), query) 621 assertNilF(t, err, "failed to run query") 622 623 if loader.TotalRows() != 0 { 624 t.Errorf("total numrows did not match expected, wanted 0, got %v", loader.TotalRows()) 625 } 626 627 batches, err := loader.GetBatches() 628 assertNilF(t, err, "failed to get result") 629 if len(batches) != 0 { 630 t.Errorf("batches length did not match expected, wanted 0, got %v", len(batches)) 631 } 632 633 rowtypes := loader.RowTypes() 634 if len(rowtypes) != 2 { 635 t.Errorf("rowTypes length did not match expected, wanted 2, got %v", len(rowtypes)) 636 } 637 }) 638 }