github.com/snowflakedb/gosnowflake@v1.9.0/chunk_downloader.go

// Copyright (c) 2021-2022 Snowflake Computing Inc. All rights reserved.

package gosnowflake

import (
	"bufio"
	"compress/gzip"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/apache/arrow/go/v15/arrow"
	"github.com/apache/arrow/go/v15/arrow/ipc"
	"github.com/apache/arrow/go/v15/arrow/memory"
)

type chunkDownloader interface {
	totalUncompressedSize() (acc int64)
	hasNextResultSet() bool
	nextResultSet() error
	start() error
	next() (chunkRowType, error)
	reset()
	getChunkMetas() []execResponseChunk
	getQueryResultFormat() resultFormat
	getRowType() []execResponseRowType
	setNextChunkDownloader(downloader chunkDownloader)
	getNextChunkDownloader() chunkDownloader
	getArrowBatches() []*ArrowBatch
}

type snowflakeChunkDownloader struct {
	sc                 *snowflakeConn
	ctx                context.Context
	pool               memory.Allocator
	Total              int64
	TotalRowIndex      int64
	CellCount          int
	CurrentChunk       []chunkRowType
	CurrentChunkIndex  int
	CurrentChunkSize   int
	CurrentIndex       int
	ChunkHeader        map[string]string
	ChunkMetas         []execResponseChunk
	Chunks             map[int][]chunkRowType
	ChunksChan         chan int
	ChunksError        chan *chunkError
	ChunksErrorCounter int
	ChunksFinalErrors  []*chunkError
	ChunksMutex        *sync.Mutex
	DoneDownloadCond   *sync.Cond
	FirstBatch         *ArrowBatch
	NextDownloader     chunkDownloader
	Qrmk               string
	QueryResultFormat  string
	ArrowBatches       []*ArrowBatch
	RowSet             rowSetType
	FuncDownload       func(context.Context, *snowflakeChunkDownloader, int)
	FuncDownloadHelper func(context.Context, *snowflakeChunkDownloader, int) error
	FuncGet            func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error)
}

func (scd *snowflakeChunkDownloader) totalUncompressedSize() (acc int64) {
	for _, c := range scd.ChunkMetas {
		acc += c.UncompressedSize
	}
	return
}

func (scd *snowflakeChunkDownloader) hasNextResultSet() bool {
	if len(scd.ChunkMetas) == 0 && scd.NextDownloader == nil {
		return false // no extra chunk
	}
	// a next result set exists if the current downloader still has chunks to
	// consume, or if another downloader is chained after this one
	return scd.CurrentChunkIndex < len(scd.ChunkMetas) || scd.NextDownloader != nil
}

func (scd *snowflakeChunkDownloader) nextResultSet() error {
	// returns nil while chunks remain, because the next chunk/result set is
	// read automatically; io.EOF signals that this downloader is exhausted
	if scd.CurrentChunkIndex < len(scd.ChunkMetas) {
		return nil
	}
	return io.EOF
}
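// drainDownloader is a hedged, illustrative sketch (not part of the driver)
// of how the interface above is driven by the rows iterator: start() is
// called once, then next() is called until it returns io.EOF. A caller that
// follows chained result sets would additionally walk getNextChunkDownloader().
func drainDownloader(cd chunkDownloader) (total int, err error) {
	if err = cd.start(); err != nil {
		return 0, err
	}
	for {
		if _, err = cd.next(); err == io.EOF {
			return total, nil
		} else if err != nil {
			return total, err
		}
		total++
	}
}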
func (scd *snowflakeChunkDownloader) start() error {
	if usesArrowBatches(scd.ctx) {
		return scd.startArrowBatches()
	}
	scd.CurrentChunkSize = len(scd.RowSet.JSON) // cache the size
	scd.CurrentIndex = -1                       // initial row index
	scd.CurrentChunkIndex = -1                  // initial chunk index

	scd.CurrentChunk = make([]chunkRowType, scd.CurrentChunkSize)
	populateJSONRowSet(scd.CurrentChunk, scd.RowSet.JSON)

	if scd.getQueryResultFormat() == arrowFormat && scd.RowSet.RowSetBase64 != "" {
		// the first result set is inlined as base64-encoded Arrow data; decode
		// it here. If the server sent an empty rowsetbase64, skip straight to
		// downloading chunks.
		var err error
		var loc *time.Location
		if scd.sc != nil && scd.sc.cfg != nil {
			loc = getCurrentLocation(scd.sc.cfg.Params)
		}
		firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
		if err != nil {
			return err
		}
		higherPrecision := higherPrecisionEnabled(scd.ctx)
		scd.CurrentChunk, err = firstArrowChunk.decodeArrowChunk(scd.RowSet.RowType, higherPrecision)
		if err != nil {
			return err
		}
		scd.CurrentChunkSize = firstArrowChunk.rowCount
	}

	// start downloading chunks if any exist
	chunkMetaLen := len(scd.ChunkMetas)
	if chunkMetaLen > 0 {
		logger.Debugf("MaxChunkDownloadWorkers: %v", MaxChunkDownloadWorkers)
		logger.Debugf("chunks: %v, total bytes: %d", chunkMetaLen, scd.totalUncompressedSize())
		scd.ChunksMutex = &sync.Mutex{}
		scd.DoneDownloadCond = sync.NewCond(scd.ChunksMutex)
		scd.Chunks = make(map[int][]chunkRowType)
		scd.ChunksChan = make(chan int, chunkMetaLen)
		scd.ChunksError = make(chan *chunkError, MaxChunkDownloadWorkers)
		for i := 0; i < chunkMetaLen; i++ {
			chunk := scd.ChunkMetas[i]
			logger.Debugf("add chunk to channel ChunksChan: %v, URL: %v, RowCount: %v, UncompressedSize: %v, ChunkResultFormat: %v",
				i+1, chunk.URL, chunk.RowCount, chunk.UncompressedSize, scd.QueryResultFormat)
			scd.ChunksChan <- i
		}
		for i := 0; i < intMin(MaxChunkDownloadWorkers, chunkMetaLen); i++ {
			scd.schedule()
		}
	}
	return nil
}

func (scd *snowflakeChunkDownloader) schedule() {
	select {
	case nextIdx := <-scd.ChunksChan:
		logger.Infof("schedule chunk: %v", nextIdx+1)
		go scd.FuncDownload(scd.ctx, scd, nextIdx)
	default:
		// no more chunks to download
		logger.Info("no more download")
	}
}
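// workerPoolSketch is a hedged, standalone illustration (not driver code) of
// the bounded-concurrency pattern above: every chunk index is queued into a
// buffered channel up front, and at most MaxChunkDownloadWorkers goroutines
// drain it. The download callback stands in for FuncDownload. Unlike this
// sketch, the driver re-arms one worker per consumed chunk via schedule(),
// which throttles downloads to roughly match consumption.
func workerPoolSketch(nChunks int, download func(idx int)) {
	queue := make(chan int, nChunks)
	for i := 0; i < nChunks; i++ {
		queue <- i
	}
	close(queue) // workers exit once the queue is drained
	var wg sync.WaitGroup
	for w := 0; w < intMin(MaxChunkDownloadWorkers, nChunks); w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for idx := range queue {
				download(idx)
			}
		}()
	}
	wg.Wait()
}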
func (scd *snowflakeChunkDownloader) checkErrorRetry() (err error) {
	select {
	case errc := <-scd.ChunksError:
		if scd.ChunksErrorCounter < maxChunkDownloaderErrorCounter &&
			errc.Error != context.Canceled &&
			errc.Error != context.DeadlineExceeded {
			// launch another download goroutine for this index so the chunk
			// is retried; the error budget still allows it
			go scd.FuncDownload(scd.ctx, scd, errc.Index)
			scd.ChunksErrorCounter++
			logger.Warningf("chunk idx: %v, err: %v. retrying (%v/%v)...",
				errc.Index, errc.Error, scd.ChunksErrorCounter, maxChunkDownloaderErrorCounter)
		} else {
			scd.ChunksFinalErrors = append(scd.ChunksFinalErrors, errc)
			logger.Warningf("chunk idx: %v, err: %v. no further retry", errc.Index, errc.Error)
			return errc.Error
		}
	default:
		logger.Info("no error detected")
	}
	return nil
}

func (scd *snowflakeChunkDownloader) next() (chunkRowType, error) {
	for {
		scd.CurrentIndex++
		if scd.CurrentIndex < scd.CurrentChunkSize {
			return scd.CurrentChunk[scd.CurrentIndex], nil
		}
		scd.CurrentChunkIndex++ // move to the next chunk
		scd.CurrentIndex = -1   // reset the row index
		if scd.CurrentChunkIndex >= len(scd.ChunkMetas) {
			break
		}

		scd.ChunksMutex.Lock()
		if scd.CurrentChunkIndex > 0 {
			scd.Chunks[scd.CurrentChunkIndex-1] = nil // detach the previously used chunk
		}

		for scd.Chunks[scd.CurrentChunkIndex] == nil {
			logger.Debugf("waiting for chunk idx: %v/%v",
				scd.CurrentChunkIndex+1, len(scd.ChunkMetas))

			if err := scd.checkErrorRetry(); err != nil {
				scd.ChunksMutex.Unlock()
				return chunkRowType{}, err
			}

			// wait for a chunk downloader goroutine to broadcast the event:
			// either a chunk download finished or an error occurred
			scd.DoneDownloadCond.Wait()
		}
		logger.Debugf("ready: chunk %v", scd.CurrentChunkIndex+1)
		scd.CurrentChunk = scd.Chunks[scd.CurrentChunkIndex]
		scd.ChunksMutex.Unlock()
		scd.CurrentChunkSize = len(scd.CurrentChunk)

		// kick off the next download
		scd.schedule()
	}

	logger.Debug("no more data")
	if len(scd.ChunkMetas) > 0 {
		close(scd.ChunksError)
		close(scd.ChunksChan)
	}
	return chunkRowType{}, io.EOF
}

func (scd *snowflakeChunkDownloader) reset() {
	scd.Chunks = nil // detach all chunks; there is no way to go backward without reinitializing
}

func (scd *snowflakeChunkDownloader) getChunkMetas() []execResponseChunk {
	return scd.ChunkMetas
}

func (scd *snowflakeChunkDownloader) getQueryResultFormat() resultFormat {
	return resultFormat(scd.QueryResultFormat)
}

func (scd *snowflakeChunkDownloader) setNextChunkDownloader(nextDownloader chunkDownloader) {
	scd.NextDownloader = nextDownloader
}

func (scd *snowflakeChunkDownloader) getNextChunkDownloader() chunkDownloader {
	return scd.NextDownloader
}

func (scd *snowflakeChunkDownloader) getRowType() []execResponseRowType {
	return scd.RowSet.RowType
}

func (scd *snowflakeChunkDownloader) getArrowBatches() []*ArrowBatch {
	if scd.FirstBatch == nil || scd.FirstBatch.rec == nil {
		return scd.ArrowBatches
	}
	return append([]*ArrowBatch{scd.FirstBatch}, scd.ArrowBatches...)
}

func getChunk(
	ctx context.Context,
	sc *snowflakeConn,
	fullURL string,
	headers map[string]string,
	timeout time.Duration) (
	*http.Response, error,
) {
	u, err := url.Parse(fullURL)
	if err != nil {
		return nil, err
	}
	return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.rest.MaxRetryCount, sc.currentTimeProvider, sc.cfg).execute()
}

func (scd *snowflakeChunkDownloader) startArrowBatches() error {
	var loc *time.Location
	if scd.sc != nil && scd.sc.cfg != nil {
		loc = getCurrentLocation(scd.sc.cfg.Params)
	}
	if scd.RowSet.RowSetBase64 != "" {
		var err error
		firstArrowChunk, err := buildFirstArrowChunk(scd.RowSet.RowSetBase64, loc, scd.pool)
		if err != nil {
			return err
		}
		scd.FirstBatch = &ArrowBatch{
			idx:                0,
			scd:                scd,
			funcDownloadHelper: scd.FuncDownloadHelper,
			loc:                loc,
		}
		// decode the first chunk eagerly if an allocator is available
		if firstArrowChunk.allocator != nil {
			scd.FirstBatch.rec, err = firstArrowChunk.decodeArrowBatch(scd)
			if err != nil {
				return err
			}
		}
	}
	chunkMetaLen := len(scd.ChunkMetas)
	scd.ArrowBatches = make([]*ArrowBatch, chunkMetaLen)
	for i := range scd.ArrowBatches {
		scd.ArrowBatches[i] = &ArrowBatch{
			idx:                i,
			scd:                scd,
			funcDownloadHelper: scd.FuncDownloadHelper,
			loc:                loc,
		}
	}
	return nil
}
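// chunkHandoffSketch is a hedged, standalone illustration (not driver code)
// of the mutex + condition-variable handoff between next() above and
// downloadChunk below: each producer stores its result under the lock and
// broadcasts, while the consumer waits until the slot it needs is filled.
// produce is assumed to return a non-nil slice.
func chunkHandoffSketch(nChunks int, produce func(idx int) []chunkRowType) {
	mu := &sync.Mutex{}
	done := sync.NewCond(mu)
	chunks := make(map[int][]chunkRowType)
	for i := 0; i < nChunks; i++ {
		go func(idx int) {
			rows := produce(idx)
			mu.Lock()
			chunks[idx] = rows
			mu.Unlock()
			done.Broadcast() // wake the consumer, mirroring DoneDownloadCond
		}(i)
	}
	for i := 0; i < nChunks; i++ { // consume in order, like next()
		mu.Lock()
		for chunks[i] == nil {
			done.Wait()
		}
		chunks[i] = nil // detach the consumed chunk
		mu.Unlock()
	}
}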
// largeResultSetReader is a reader that wraps the large result set with
// leading and trailing brackets so it parses as a single JSON array.
type largeResultSetReader struct {
	status int
	body   io.Reader
}

func (r *largeResultSetReader) Read(p []byte) (n int, err error) {
	if r.status == 0 {
		p[0] = 0x5b // emit the leading 0x5b ([)
		r.status = 1
		return 1, nil
	}
	if r.status == 1 {
		n, err = r.body.Read(p)
		if err == io.EOF {
			r.status = 2
			return n, nil
		}
		if err != nil {
			return 0, err
		}
		return n, nil
	}
	if r.status == 2 {
		p[0] = 0x5d // emit the trailing 0x5d (])
		r.status = 3
		return 1, nil
	}
	// no more data; signal EOF
	return 0, io.EOF
}

func downloadChunk(ctx context.Context, scd *snowflakeChunkDownloader, idx int) {
	logger.Infof("download start chunk: %v", idx+1)
	defer scd.DoneDownloadCond.Broadcast()

	if err := scd.FuncDownloadHelper(ctx, scd, idx); err != nil {
		logger.Errorf(
			"failed to download or decode chunk. URL: %v, err: %v", scd.ChunkMetas[idx].URL, err)
		scd.ChunksError <- &chunkError{Index: idx, Error: err}
	} else if scd.ctx.Err() == context.Canceled || scd.ctx.Err() == context.DeadlineExceeded {
		scd.ChunksError <- &chunkError{Index: idx, Error: scd.ctx.Err()}
	}
}

func downloadChunkHelper(ctx context.Context, scd *snowflakeChunkDownloader, idx int) error {
	headers := make(map[string]string)
	if len(scd.ChunkHeader) > 0 {
		logger.Debug("chunk header is provided.")
		for k, v := range scd.ChunkHeader {
			logger.Debugf("adding header: %v, value: %v", k, v)
			headers[k] = v
		}
	} else {
		headers[headerSseCAlgorithm] = headerSseCAes
		headers[headerSseCKey] = scd.Qrmk
	}

	resp, err := scd.FuncGet(ctx, scd.sc, scd.ChunkMetas[idx].URL, headers, scd.sc.rest.RequestTimeout)
	if err != nil {
		return err
	}
	bufStream := bufio.NewReader(resp.Body)
	defer resp.Body.Close()
	logger.Debugf("response returned chunk: %v for URL: %v", idx+1, scd.ChunkMetas[idx].URL)
	if resp.StatusCode != http.StatusOK {
		b, err := io.ReadAll(bufStream)
		if err != nil {
			return err
		}
		logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, scd.ChunkMetas[idx].URL, string(b))
		logger.Infof("Header: %v", resp.Header)
		return &SnowflakeError{
			Number:      ErrFailedToGetChunk,
			SQLState:    SQLStateConnectionFailure,
			Message:     errMsgFailedToGetChunk,
			MessageArgs: []interface{}{idx},
		}
	}
	return decodeChunk(scd, idx, bufStream)
}
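// decodeWrappedRows is a hedged usage sketch for largeResultSetReader: a raw
// chunk body arrives as comma-separated JSON rows with no enclosing array, so
// wrapping it in brackets lets a single json.Decoder pass produce the rows.
// decodeChunk below applies the same pattern to downloaded chunk bodies.
func decodeWrappedRows(body io.Reader) ([][]*string, error) {
	var rows [][]*string
	dec := json.NewDecoder(&largeResultSetReader{status: 0, body: body})
	for {
		if err := dec.Decode(&rows); err == io.EOF {
			break
		} else if err != nil {
			return nil, err
		}
	}
	return rows, nil
}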
func decodeChunk(scd *snowflakeChunkDownloader, idx int, bufStream *bufio.Reader) (err error) {
	gzipMagic, err := bufStream.Peek(2)
	if err != nil {
		return err
	}
	start := time.Now()
	var source io.Reader
	if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b {
		// detect and decompress gzip-encoded data
		bufStream0, err := gzip.NewReader(bufStream)
		if err != nil {
			return err
		}
		defer bufStream0.Close()
		source = bufStream0
	} else {
		source = bufStream
	}
	st := &largeResultSetReader{
		status: 0,
		body:   source,
	}
	var respd []chunkRowType
	if scd.getQueryResultFormat() != arrowFormat {
		var decRespd [][]*string
		if !CustomJSONDecoderEnabled {
			dec := json.NewDecoder(st)
			for {
				if err = dec.Decode(&decRespd); err == io.EOF {
					break
				} else if err != nil {
					return err
				}
			}
		} else {
			decRespd, err = decodeLargeChunk(st, scd.ChunkMetas[idx].RowCount, scd.CellCount)
			if err != nil {
				return err
			}
		}
		respd = make([]chunkRowType, len(decRespd))
		populateJSONRowSet(respd, decRespd)
	} else {
		ipcReader, err := ipc.NewReader(source, ipc.WithAllocator(scd.pool))
		if err != nil {
			return err
		}
		var loc *time.Location
		if scd.sc != nil && scd.sc.cfg != nil {
			loc = getCurrentLocation(scd.sc.cfg.Params)
		}
		arc := arrowResultChunk{
			ipcReader,
			0,
			loc,
			scd.pool,
		}
		if usesArrowBatches(scd.ctx) {
			if scd.ArrowBatches[idx].rec, err = arc.decodeArrowBatch(scd); err != nil {
				return err
			}
			// update the batch metadata with the decoded row count
			scd.ArrowBatches[idx].rowCount = countArrowBatchRows(scd.ArrowBatches[idx].rec)
			return nil
		}
		highPrec := higherPrecisionEnabled(scd.ctx)
		respd, err = arc.decodeArrowChunk(scd.RowSet.RowType, highPrec)
		if err != nil {
			return err
		}
	}
	logger.Debugf(
		"decoded %d rows w/ %d bytes in %s (chunk %v)",
		scd.ChunkMetas[idx].RowCount,
		scd.ChunkMetas[idx].UncompressedSize,
		time.Since(start), idx+1,
	)

	scd.ChunksMutex.Lock()
	defer scd.ChunksMutex.Unlock()
	scd.Chunks[idx] = respd
	return nil
}

func populateJSONRowSet(dst []chunkRowType, src [][]*string) {
	// copy each row of the string rowset from src into the RowSet field of
	// dst's chunkRowType entries
	for i, row := range src {
		dst[i].RowSet = row
	}
}
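// maybeGunzip is a hedged sketch of the gzip sniffing used by decodeChunk
// above and copyChunkStream below: peek at the first two bytes without
// consuming them and wrap the stream in a gzip reader only when the
// 0x1f 0x8b magic number is present. The returned closer is a no-op for
// plain streams.
func maybeGunzip(r io.Reader) (io.Reader, func() error, error) {
	buf := bufio.NewReader(r)
	magic, err := buf.Peek(2)
	if err != nil {
		return nil, nil, err
	}
	if magic[0] == 0x1f && magic[1] == 0x8b {
		gz, err := gzip.NewReader(buf)
		if err != nil {
			return nil, nil, err
		}
		return gz, gz.Close, nil
	}
	return buf, func() error { return nil }, nil
}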
type streamChunkDownloader struct {
	ctx            context.Context
	id             int64
	fetcher        streamChunkFetcher
	readErr        error
	rowStream      chan []*string
	Total          int64
	ChunkMetas     []execResponseChunk
	NextDownloader chunkDownloader
	RowSet         rowSetType
}

func (scd *streamChunkDownloader) totalUncompressedSize() (acc int64) {
	return -1
}

func (scd *streamChunkDownloader) hasNextResultSet() bool {
	return scd.readErr == nil
}

func (scd *streamChunkDownloader) nextResultSet() error {
	return scd.readErr
}

func (scd *streamChunkDownloader) start() error {
	go func() {
		readErr := io.EOF

		logger.WithContext(scd.ctx).Infof(
			"start downloading. downloader id: %v, %v/%v rows, %v chunks",
			scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas))
		t := time.Now()

		defer func() {
			// recover first so a panic is captured in readErr before the
			// error is published and the row stream is closed
			if r := recover(); r != nil {
				if err, ok := r.(error); ok {
					readErr = err
				} else {
					readErr = fmt.Errorf("%v", r)
				}
			}

			if readErr == io.EOF {
				logger.WithContext(scd.ctx).Infof("downloading done. downloader id: %v", scd.id)
			} else {
				logger.WithContext(scd.ctx).Debugf("downloading error. downloader id: %v", scd.id)
			}
			scd.readErr = readErr
			close(scd.rowStream)
		}()

		for _, row := range scd.RowSet.JSON {
			scd.rowStream <- row
		}
		scd.RowSet.JSON = nil
		logger.WithContext(scd.ctx).Infof("sent initial set of rows in %vµs", time.Since(t).Microseconds())
		t = time.Now()

		// Download and parse one chunk at a time. The fetcher sends each
		// parsed row to the row stream. When an error occurs, the fetcher
		// stops writing to the row stream so processing can stop immediately.
		for i, chunk := range scd.ChunkMetas {
			logger.WithContext(scd.ctx).Infof("starting chunk fetch %d (%d rows)", i, chunk.RowCount)
			if err := scd.fetcher.fetch(chunk.URL, scd.rowStream); err != nil {
				logger.WithContext(scd.ctx).Debugf(
					"failed chunk fetch %d: %#v, downloader id: %v, %v/%v rows, %v chunks",
					i, err, scd.id, len(scd.RowSet.RowType), scd.Total, len(scd.ChunkMetas))
				readErr = fmt.Errorf("chunk fetch: %w", err)
				break
			}
			logger.WithContext(scd.ctx).Infof("fetched chunk %d (%d rows) in %vµs", i, chunk.RowCount, time.Since(t).Microseconds())
			t = time.Now()
		}
	}()
	return nil
}

func (scd *streamChunkDownloader) next() (chunkRowType, error) {
	if row, ok := <-scd.rowStream; ok {
		return chunkRowType{RowSet: row}, nil
	}
	return chunkRowType{}, scd.readErr
}

func (scd *streamChunkDownloader) reset() {}

func (scd *streamChunkDownloader) getChunkMetas() []execResponseChunk {
	return scd.ChunkMetas
}

func (scd *streamChunkDownloader) getQueryResultFormat() resultFormat {
	return jsonFormat
}

func (scd *streamChunkDownloader) setNextChunkDownloader(nextDownloader chunkDownloader) {
	scd.NextDownloader = nextDownloader
}

func (scd *streamChunkDownloader) getNextChunkDownloader() chunkDownloader {
	return scd.NextDownloader
}

func (scd *streamChunkDownloader) getRowType() []execResponseRowType {
	return scd.RowSet.RowType
}

func (scd *streamChunkDownloader) getArrowBatches() []*ArrowBatch {
	return nil
}

func useStreamDownloader(ctx context.Context) bool {
	val := ctx.Value(streamChunkDownload)
	if val == nil {
		return false
	}
	s, ok := val.(bool)
	return ok && s
}

type streamChunkFetcher interface {
	fetch(url string, rows chan<- []*string) error
}

type httpStreamChunkFetcher struct {
	ctx      context.Context
	client   *http.Client
	clientIP net.IP
	headers  map[string]string
	qrmk     string
}

func newStreamChunkDownloader(
	ctx context.Context,
	fetcher streamChunkFetcher,
	total int64,
	rowType []execResponseRowType,
	firstRows [][]*string,
	chunks []execResponseChunk,
) *streamChunkDownloader {
	return &streamChunkDownloader{
		ctx:        ctx,
		id:         rand.Int63(),
		fetcher:    fetcher,
		readErr:    nil,
		rowStream:  make(chan []*string),
		Total:      total,
		ChunkMetas: chunks,
		RowSet:     rowSetType{RowType: rowType, JSON: firstRows},
	}
}

func (f *httpStreamChunkFetcher) fetch(URL string, rows chan<- []*string) error {
	if len(f.headers) == 0 {
		f.headers = map[string]string{
			headerSseCAlgorithm: headerSseCAes,
			headerSseCKey:       f.qrmk,
		}
	}

	fullURL, err := url.Parse(URL)
	if err != nil {
		return err
	}
	res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, 0, defaultTimeProvider, nil).execute()
	if err != nil {
		return err
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		b, err := io.ReadAll(res.Body)
		if err != nil {
			return err
		}
		return fmt.Errorf("status (%d): %s", res.StatusCode, string(b))
	}
	if err = copyChunkStream(res.Body, rows); err != nil {
		return fmt.Errorf("read: %w", err)
	}
	return nil
}
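// streamRowsSketch is a hedged illustration (not driver code) of how a
// fetcher's row channel is consumed: copyChunkStream (below) parses rows into
// the channel from one goroutine while the caller ranges over it, mirroring
// the producer/consumer split between start() and next() above.
func streamRowsSketch(body io.Reader) error {
	rows := make(chan []*string)
	errc := make(chan error, 1)
	go func() {
		errc <- copyChunkStream(body, rows)
		close(rows)
	}()
	for row := range rows {
		_ = row // each parsed row arrives here as soon as it is decoded
	}
	return <-errc
}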
func copyChunkStream(body io.Reader, rows chan<- []*string) error {
	bufStream := bufio.NewReader(body)
	gzipMagic, err := bufStream.Peek(2)
	if err != nil {
		return err
	}
	var source io.Reader
	if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b {
		// detect and decompress gzip-encoded data
		bufStream0, err := gzip.NewReader(bufStream)
		if err != nil {
			return err
		}
		defer bufStream0.Close()
		source = bufStream0
	} else {
		source = bufStream
	}
	r := io.MultiReader(strings.NewReader("["), source, strings.NewReader("]"))
	dec := json.NewDecoder(r)
	openToken := json.Delim('[')
	closeToken := json.Delim(']')
	for {
		if t, err := dec.Token(); err == io.EOF {
			break
		} else if err != nil {
			return fmt.Errorf("delim open: %w", err)
		} else if t != openToken {
			return fmt.Errorf("delim open: got %T", t)
		}
		for dec.More() {
			var row []*string
			if err = dec.Decode(&row); err != nil {
				return fmt.Errorf("decode: %w", err)
			}
			rows <- row
		}
		if t, err := dec.Token(); err != nil {
			return fmt.Errorf("delim close: %w", err)
		} else if t != closeToken {
			return fmt.Errorf("delim close: got %T", t)
		}
	}
	return nil
}

// ArrowBatch represents a chunk of data, or subset of rows, retrievable in
// arrow.Record format.
type ArrowBatch struct {
	rec                *[]arrow.Record
	idx                int
	rowCount           int
	scd                *snowflakeChunkDownloader
	funcDownloadHelper func(context.Context, *snowflakeChunkDownloader, int) error
	ctx                context.Context
	loc                *time.Location
}

// WithContext sets the context which will be used for this ArrowBatch.
func (rb *ArrowBatch) WithContext(ctx context.Context) *ArrowBatch {
	rb.ctx = ctx
	return rb
}

// Fetch returns a slice of records representing one chunk of the query,
// downloading the chunk first if necessary.
func (rb *ArrowBatch) Fetch() (*[]arrow.Record, error) {
	// the chunk has already been downloaded
	if rb.rec != nil {
		// update the row-count metadata
		rb.rowCount = countArrowBatchRows(rb.rec)
		return rb.rec, nil
	}
	var ctx context.Context
	if rb.ctx != nil {
		ctx = rb.ctx
	} else {
		ctx = context.Background()
	}
	if err := rb.funcDownloadHelper(ctx, rb.scd, rb.idx); err != nil {
		return nil, err
	}
	return rb.rec, nil
}

// GetRowCount returns the number of rows in an arrow batch.
func (rb *ArrowBatch) GetRowCount() int {
	return rb.rowCount
}

func getAllocator(ctx context.Context) memory.Allocator {
	pool, ok := ctx.Value(arrowAlloc).(memory.Allocator)
	if !ok {
		return memory.DefaultAllocator
	}
	return pool
}

func usesArrowBatches(ctx context.Context) bool {
	val := ctx.Value(arrowBatches)
	if val == nil {
		return false
	}
	a, ok := val.(bool)
	return ok && a
}

func countArrowBatchRows(recs *[]arrow.Record) int {
	var cnt int
	for _, r := range *recs {
		cnt += int(r.NumRows())
	}
	return cnt
}
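// fetchAllBatches is a hedged usage sketch for the ArrowBatch API above
// (assuming the downloader was started with Arrow batches enabled via the
// arrowBatches context flag): fetch every batch, tally the rows, and release
// the records once they are consumed.
func fetchAllBatches(cd chunkDownloader) (int, error) {
	total := 0
	for _, batch := range cd.getArrowBatches() {
		recs, err := batch.Fetch() // downloads the chunk on first use
		if err != nil {
			return total, err
		}
		total += batch.GetRowCount()
		for _, rec := range *recs {
			rec.Release() // Arrow records are reference counted
		}
	}
	return total, nil
}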