github.com/snowflakedb/gosnowflake@v1.9.0/connection.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 import ( 6 "bufio" 7 "bytes" 8 "compress/gzip" 9 "context" 10 "database/sql" 11 "database/sql/driver" 12 "encoding/base64" 13 "encoding/json" 14 "io" 15 "net/http" 16 "net/url" 17 "os" 18 "regexp" 19 "strconv" 20 "strings" 21 "sync" 22 "sync/atomic" 23 "time" 24 25 "github.com/apache/arrow/go/v15/arrow/ipc" 26 ) 27 28 const ( 29 httpHeaderContentType = "Content-Type" 30 httpHeaderAccept = "accept" 31 httpHeaderUserAgent = "User-Agent" 32 httpHeaderServiceName = "X-Snowflake-Service" 33 httpHeaderContentLength = "Content-Length" 34 httpHeaderHost = "Host" 35 httpHeaderValueOctetStream = "application/octet-stream" 36 httpHeaderContentEncoding = "Content-Encoding" 37 httpClientAppID = "CLIENT_APP_ID" 38 httpClientAppVersion = "CLIENT_APP_VERSION" 39 ) 40 41 const ( 42 statementTypeIDSelect = int64(0x1000) 43 statementTypeIDDml = int64(0x3000) 44 statementTypeIDMultiTableInsert = statementTypeIDDml + int64(0x500) 45 statementTypeIDMultistatement = int64(0xA000) 46 ) 47 48 const ( 49 sessionClientSessionKeepAlive = "client_session_keep_alive" 50 sessionClientValidateDefaultParameters = "CLIENT_VALIDATE_DEFAULT_PARAMETERS" 51 sessionArrayBindStageThreshold = "client_stage_array_binding_threshold" 52 serviceName = "service_name" 53 ) 54 55 type resultType string 56 57 const ( 58 snowflakeResultType contextKey = "snowflakeResultType" 59 execResultType resultType = "exec" 60 queryResultType resultType = "query" 61 ) 62 63 type execKey string 64 65 const ( 66 executionType execKey = "executionType" 67 executionTypeStatement string = "statement" 68 ) 69 70 const privateLinkSuffix = "privatelink.snowflakecomputing.com" 71 72 type snowflakeConn struct { 73 ctx context.Context 74 cfg *Config 75 rest *snowflakeRestful 76 SequenceCounter uint64 77 telemetry *snowflakeTelemetry 78 internal InternalClient 79 queryContextCache *queryContextCache 80 currentTimeProvider currentTimeProvider 81 } 82 83 var ( 84 queryIDPattern = `[\w\-_]+` 85 queryIDRegexp = regexp.MustCompile(queryIDPattern) 86 ) 87 88 func (sc *snowflakeConn) exec( 89 ctx context.Context, 90 query string, 91 noResult bool, 92 isInternal bool, 93 describeOnly bool, 94 bindings []driver.NamedValue) ( 95 *execResponse, error) { 96 var err error 97 counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter 98 99 queryContext, err := buildQueryContext(sc.queryContextCache) 100 if err != nil { 101 logger.Errorf("error while building query context: %v", err) 102 } 103 req := execRequest{ 104 SQLText: query, 105 AsyncExec: noResult, 106 Parameters: map[string]interface{}{}, 107 IsInternal: isInternal, 108 DescribeOnly: describeOnly, 109 SequenceID: counter, 110 QueryContext: queryContext, 111 } 112 if key := ctx.Value(multiStatementCount); key != nil { 113 req.Parameters[string(multiStatementCount)] = key 114 } 115 if tag := ctx.Value(queryTag); tag != nil { 116 req.Parameters[string(queryTag)] = tag 117 } 118 logger.WithContext(ctx).Infof("parameters: %v", req.Parameters) 119 120 // handle bindings, if required 121 requestID := getOrGenerateRequestIDFromContext(ctx) 122 if len(bindings) > 0 { 123 if err = sc.processBindings(ctx, bindings, describeOnly, requestID, &req); err != nil { 124 return nil, err 125 } 126 } 127 logger.WithContext(ctx).Infof("bindings: %v", req.Bindings) 128 129 // populate headers 130 headers := getHeaders() 131 if isFileTransfer(query) { 132 headers[httpHeaderAccept] = headerContentTypeApplicationJSON 133 } 134 paramsMutex.Lock() 135 if serviceName, ok := sc.cfg.Params[serviceName]; ok { 136 headers[httpHeaderServiceName] = *serviceName 137 } 138 paramsMutex.Unlock() 139 140 jsonBody, err := json.Marshal(req) 141 if err != nil { 142 return nil, err 143 } 144 145 data, err := sc.rest.FuncPostQuery(ctx, sc.rest, &url.Values{}, headers, 146 jsonBody, sc.rest.RequestTimeout, requestID, sc.cfg) 147 if err != nil { 148 return data, err 149 } 150 code := -1 151 if data.Code != "" { 152 code, err = strconv.Atoi(data.Code) 153 if err != nil { 154 return data, err 155 } 156 } 157 logger.WithContext(ctx).Infof("Success: %v, Code: %v", data.Success, code) 158 if !data.Success { 159 err = (populateErrorFields(code, data)).exceptionTelemetry(sc) 160 return nil, err 161 } 162 163 if !sc.cfg.DisableQueryContextCache && data.Data.QueryContext != nil { 164 queryContext, err := extractQueryContext(data) 165 if err != nil { 166 logger.Errorf("error while decoding query context: ", err) 167 } else { 168 sc.queryContextCache.add(sc, queryContext.Entries...) 169 } 170 } 171 172 // handle PUT/GET commands 173 if isFileTransfer(query) { 174 data, err = sc.processFileTransfer(ctx, data, query, isInternal) 175 if err != nil { 176 return nil, err 177 } 178 } 179 180 logger.WithContext(ctx).Info("Exec/Query SUCCESS") 181 if data.Data.FinalDatabaseName != "" { 182 sc.cfg.Database = data.Data.FinalDatabaseName 183 } 184 if data.Data.FinalSchemaName != "" { 185 sc.cfg.Schema = data.Data.FinalSchemaName 186 } 187 if data.Data.FinalWarehouseName != "" { 188 sc.cfg.Warehouse = data.Data.FinalWarehouseName 189 } 190 if data.Data.FinalRoleName != "" { 191 sc.cfg.Role = data.Data.FinalRoleName 192 } 193 sc.populateSessionParameters(data.Data.Parameters) 194 return data, err 195 } 196 197 func extractQueryContext(data *execResponse) (queryContext, error) { 198 var queryContext queryContext 199 err := json.Unmarshal(data.Data.QueryContext, &queryContext) 200 return queryContext, err 201 } 202 203 func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) { 204 rqc := requestQueryContext{} 205 if qcc == nil || len(qcc.entries) == 0 { 206 logger.Debugf("empty qcc") 207 return rqc, nil 208 } 209 for _, qce := range qcc.entries { 210 contextData := contextData{} 211 if qce.Context == "" { 212 contextData.Base64Data = qce.Context 213 } 214 rqc.Entries = append(rqc.Entries, requestQueryContextEntry{ 215 ID: qce.ID, 216 Priority: qce.Priority, 217 Timestamp: qce.Timestamp, 218 Context: contextData, 219 }) 220 } 221 return rqc, nil 222 } 223 224 func (sc *snowflakeConn) Begin() (driver.Tx, error) { 225 return sc.BeginTx(sc.ctx, driver.TxOptions{}) 226 } 227 228 func (sc *snowflakeConn) BeginTx( 229 ctx context.Context, 230 opts driver.TxOptions) ( 231 driver.Tx, error) { 232 logger.WithContext(ctx).Info("BeginTx") 233 if opts.ReadOnly { 234 return nil, (&SnowflakeError{ 235 Number: ErrNoReadOnlyTransaction, 236 SQLState: SQLStateFeatureNotSupported, 237 Message: errMsgNoReadOnlyTransaction, 238 }).exceptionTelemetry(sc) 239 } 240 if int(opts.Isolation) != int(sql.LevelDefault) { 241 return nil, (&SnowflakeError{ 242 Number: ErrNoDefaultTransactionIsolationLevel, 243 SQLState: SQLStateFeatureNotSupported, 244 Message: errMsgNoDefaultTransactionIsolationLevel, 245 }).exceptionTelemetry(sc) 246 } 247 if sc.rest == nil { 248 return nil, driver.ErrBadConn 249 } 250 isDesc := isDescribeOnly(ctx) 251 if _, err := sc.exec(ctx, "BEGIN", false, /* noResult */ 252 false /* isInternal */, isDesc, nil); err != nil { 253 return nil, err 254 } 255 return &snowflakeTx{sc, ctx}, nil 256 } 257 258 func (sc *snowflakeConn) cleanup() { 259 // must flush log buffer while the process is running. 260 if sc.rest != nil && sc.rest.Client != nil { 261 sc.rest.Client.CloseIdleConnections() 262 } 263 sc.rest = nil 264 sc.cfg = nil 265 } 266 267 func (sc *snowflakeConn) Close() (err error) { 268 logger.WithContext(sc.ctx).Infoln("Close") 269 sc.telemetry.sendBatch() 270 sc.stopHeartBeat() 271 defer sc.cleanup() 272 273 if sc.cfg != nil && !sc.cfg.KeepSessionAlive { 274 if err = sc.rest.FuncCloseSession(sc.ctx, sc.rest, sc.rest.RequestTimeout); err != nil { 275 logger.Error(err) 276 } 277 } 278 return nil 279 } 280 281 func (sc *snowflakeConn) PrepareContext( 282 ctx context.Context, 283 query string) ( 284 driver.Stmt, error) { 285 logger.WithContext(sc.ctx).Infoln("Prepare") 286 if sc.rest == nil { 287 return nil, driver.ErrBadConn 288 } 289 stmt := &snowflakeStmt{ 290 sc: sc, 291 query: query, 292 } 293 return stmt, nil 294 } 295 296 func (sc *snowflakeConn) ExecContext( 297 ctx context.Context, 298 query string, 299 args []driver.NamedValue) ( 300 driver.Result, error) { 301 logger.WithContext(ctx).Infof("Exec: %#v, %v", query, args) 302 if sc.rest == nil { 303 return nil, driver.ErrBadConn 304 } 305 noResult := isAsyncMode(ctx) 306 isDesc := isDescribeOnly(ctx) 307 // TODO handle isInternal 308 ctx = setResultType(ctx, execResultType) 309 data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args) 310 if err != nil { 311 logger.WithContext(ctx).Infof("error: %v", err) 312 if data != nil { 313 code, e := strconv.Atoi(data.Code) 314 if e != nil { 315 return nil, e 316 } 317 return nil, (&SnowflakeError{ 318 Number: code, 319 SQLState: data.Data.SQLState, 320 Message: err.Error(), 321 QueryID: data.Data.QueryID, 322 }).exceptionTelemetry(sc) 323 } 324 return nil, err 325 } 326 327 // if async exec, return result object right away 328 if noResult { 329 return data.Data.AsyncResult, nil 330 } 331 332 if isDml(data.Data.StatementTypeID) { 333 // collects all values from the returned row sets 334 updatedRows, err := updateRows(data.Data) 335 if err != nil { 336 return nil, err 337 } 338 logger.WithContext(ctx).Debugf("number of updated rows: %#v", updatedRows) 339 return &snowflakeResult{ 340 affectedRows: updatedRows, 341 insertID: -1, 342 queryID: data.Data.QueryID, 343 }, nil // last insert id is not supported by Snowflake 344 } else if isMultiStmt(&data.Data) { 345 return sc.handleMultiExec(ctx, data.Data) 346 } else if isDql(&data.Data) { 347 logger.WithContext(ctx).Debugf("DQL") 348 if isStatementContext(ctx) { 349 return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil 350 } 351 return driver.ResultNoRows, nil 352 } 353 logger.Debug("DDL") 354 if isStatementContext(ctx) { 355 return &snowflakeResultNoRows{queryID: data.Data.QueryID}, nil 356 } 357 return driver.ResultNoRows, nil 358 } 359 360 func (sc *snowflakeConn) QueryContext( 361 ctx context.Context, 362 query string, 363 args []driver.NamedValue) ( 364 driver.Rows, error) { 365 qid, err := getResumeQueryID(ctx) 366 if err != nil { 367 return nil, err 368 } 369 if qid == "" { 370 return sc.queryContextInternal(ctx, query, args) 371 } 372 373 // check the query status to find out if there is a result to fetch 374 _, err = sc.checkQueryStatus(ctx, qid) 375 snowflakeErr, isSnowflakeError := err.(*SnowflakeError) 376 if err == nil || (isSnowflakeError && snowflakeErr.Number == ErrQueryIsRunning) { 377 // the query is running. Rows object will be returned from here. 378 return sc.buildRowsForRunningQuery(ctx, qid) 379 } 380 return nil, err 381 } 382 383 func (sc *snowflakeConn) queryContextInternal( 384 ctx context.Context, 385 query string, 386 args []driver.NamedValue) ( 387 driver.Rows, error) { 388 logger.WithContext(ctx).Infof("Query: %#v, %v", query, args) 389 if sc.rest == nil { 390 return nil, driver.ErrBadConn 391 } 392 393 noResult := isAsyncMode(ctx) 394 isDesc := isDescribeOnly(ctx) 395 ctx = setResultType(ctx, queryResultType) 396 // TODO: handle isInternal 397 data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args) 398 if err != nil { 399 logger.WithContext(ctx).Errorf("error: %v", err) 400 if data != nil { 401 code, e := strconv.Atoi(data.Code) 402 if e != nil { 403 return nil, e 404 } 405 return nil, (&SnowflakeError{ 406 Number: code, 407 SQLState: data.Data.SQLState, 408 Message: err.Error(), 409 QueryID: data.Data.QueryID, 410 }).exceptionTelemetry(sc) 411 } 412 return nil, err 413 } 414 415 // if async query, return row object right away 416 if noResult { 417 return data.Data.AsyncRows, nil 418 } 419 420 rows := new(snowflakeRows) 421 rows.sc = sc 422 rows.queryID = data.Data.QueryID 423 424 if isMultiStmt(&data.Data) { 425 // handleMultiQuery is responsible to fill rows with childResults 426 if err = sc.handleMultiQuery(ctx, data.Data, rows); err != nil { 427 return nil, err 428 } 429 } else { 430 rows.addDownloader(populateChunkDownloader(ctx, sc, data.Data)) 431 } 432 433 err = rows.ChunkDownloader.start() 434 return rows, err 435 } 436 437 func (sc *snowflakeConn) Prepare(query string) (driver.Stmt, error) { 438 return sc.PrepareContext(sc.ctx, query) 439 } 440 441 func (sc *snowflakeConn) Exec( 442 query string, 443 args []driver.Value) ( 444 driver.Result, error) { 445 return sc.ExecContext(sc.ctx, query, toNamedValues(args)) 446 } 447 448 func (sc *snowflakeConn) Query( 449 query string, 450 args []driver.Value) ( 451 driver.Rows, error) { 452 return sc.QueryContext(sc.ctx, query, toNamedValues(args)) 453 } 454 455 func (sc *snowflakeConn) Ping(ctx context.Context) error { 456 logger.WithContext(ctx).Infoln("Ping") 457 if sc.rest == nil { 458 return driver.ErrBadConn 459 } 460 noResult := isAsyncMode(ctx) 461 isDesc := isDescribeOnly(ctx) 462 // TODO: handle isInternal 463 ctx = setResultType(ctx, execResultType) 464 _, err := sc.exec(ctx, "SELECT 1", noResult, false, /* isInternal */ 465 isDesc, []driver.NamedValue{}) 466 return err 467 } 468 469 // CheckNamedValue determines which types are handled by this driver aside from 470 // the instances captured by driver.Value 471 func (sc *snowflakeConn) CheckNamedValue(nv *driver.NamedValue) error { 472 if supportedNullBind(nv) || supportedArrayBind(nv) { 473 return nil 474 } 475 return driver.ErrSkip 476 } 477 478 func (sc *snowflakeConn) GetQueryStatus( 479 ctx context.Context, 480 queryID string) ( 481 *SnowflakeQueryStatus, error) { 482 queryRet, err := sc.checkQueryStatus(ctx, queryID) 483 if err != nil { 484 return nil, err 485 } 486 return &SnowflakeQueryStatus{ 487 queryRet.SQLText, 488 queryRet.StartTime, 489 queryRet.EndTime, 490 queryRet.ErrorCode, 491 queryRet.ErrorMessage, 492 queryRet.Stats.ScanBytes, 493 queryRet.Stats.ProducedRows, 494 }, nil 495 } 496 497 // QueryArrowStream returns batches which can be queried for their raw arrow 498 // ipc stream of bytes. This way consumers don't need to be using the exact 499 // same version of Arrow as the connection is using internally in order 500 // to consume Arrow data. 501 func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bindings ...driver.NamedValue) (ArrowStreamLoader, error) { 502 ctx = WithArrowBatches(context.WithValue(ctx, asyncMode, false)) 503 ctx = setResultType(ctx, queryResultType) 504 isDesc := isDescribeOnly(ctx) 505 data, err := sc.exec(ctx, query, false, false /* isinternal */, isDesc, bindings) 506 if err != nil { 507 logger.WithContext(ctx).Errorf("error: %v", err) 508 if data != nil { 509 code, e := strconv.Atoi(data.Code) 510 if e != nil { 511 return nil, e 512 } 513 return nil, (&SnowflakeError{ 514 Number: code, 515 SQLState: data.Data.SQLState, 516 Message: err.Error(), 517 QueryID: data.Data.QueryID, 518 }).exceptionTelemetry(sc) 519 } 520 return nil, err 521 } 522 523 return &snowflakeArrowStreamChunkDownloader{ 524 sc: sc, 525 ChunkMetas: data.Data.Chunks, 526 Total: data.Data.Total, 527 Qrmk: data.Data.Qrmk, 528 ChunkHeader: data.Data.ChunkHeaders, 529 FuncGet: getChunk, 530 RowSet: rowSetType{ 531 RowType: data.Data.RowType, 532 JSON: data.Data.RowSet, 533 RowSetBase64: data.Data.RowSetBase64, 534 }, 535 }, nil 536 } 537 538 // ArrowStreamBatch is a type describing a potentially yet-to-be-downloaded 539 // Arrow IPC stream. Call `GetStream` to download and retrieve an io.Reader 540 // that can be used with ipc.NewReader to get record batch results. 541 type ArrowStreamBatch struct { 542 idx int 543 numrows int64 544 scd *snowflakeArrowStreamChunkDownloader 545 Loc *time.Location 546 rr io.ReadCloser 547 } 548 549 // NumRows returns the total number of rows that the metadata stated should 550 // be in this stream of record batches. 551 func (asb *ArrowStreamBatch) NumRows() int64 { return asb.numrows } 552 553 // gzip.Reader.Close does NOT close the underlying reader, so we 554 // need to wrap with wrapReader so that closing will close the 555 // response body (or any other reader that we want to gzip uncompress) 556 type wrapReader struct { 557 io.Reader 558 wrapped io.ReadCloser 559 } 560 561 func (w *wrapReader) Close() error { 562 if cl, ok := w.Reader.(io.ReadCloser); ok { 563 if err := cl.Close(); err != nil { 564 return err 565 } 566 } 567 return w.wrapped.Close() 568 } 569 570 func (asb *ArrowStreamBatch) downloadChunkStreamHelper(ctx context.Context) error { 571 headers := make(map[string]string) 572 if len(asb.scd.ChunkHeader) > 0 { 573 logger.Debug("chunk header is provided") 574 for k, v := range asb.scd.ChunkHeader { 575 logger.Debugf("adding header: %v, value: %v", k, v) 576 577 headers[k] = v 578 } 579 } else { 580 headers[headerSseCAlgorithm] = headerSseCAes 581 headers[headerSseCKey] = asb.scd.Qrmk 582 } 583 584 resp, err := asb.scd.FuncGet(ctx, asb.scd.sc, asb.scd.ChunkMetas[asb.idx].URL, headers, asb.scd.sc.rest.RequestTimeout) 585 if err != nil { 586 return err 587 } 588 logger.Debugf("response returned chunk: %v for URL: %v", asb.idx+1, asb.scd.ChunkMetas[asb.idx].URL) 589 if resp.StatusCode != http.StatusOK { 590 defer resp.Body.Close() 591 b, err := io.ReadAll(resp.Body) 592 if err != nil { 593 return err 594 } 595 596 logger.Infof("HTTP: %v, URL: %v, Body: %v", resp.StatusCode, asb.scd.ChunkMetas[asb.idx].URL, b) 597 logger.Infof("Header: %v", resp.Header) 598 return &SnowflakeError{ 599 Number: ErrFailedToGetChunk, 600 SQLState: SQLStateConnectionFailure, 601 Message: errMsgFailedToGetChunk, 602 MessageArgs: []interface{}{asb.idx}, 603 } 604 } 605 606 defer func() { 607 if asb.rr == nil { 608 resp.Body.Close() 609 } 610 }() 611 612 bufStream := bufio.NewReader(resp.Body) 613 gzipMagic, err := bufStream.Peek(2) 614 if err != nil { 615 return err 616 } 617 618 if gzipMagic[0] == 0x1f && gzipMagic[1] == 0x8b { 619 // detect and uncompress gzip 620 bufStream0, err := gzip.NewReader(bufStream) 621 if err != nil { 622 return err 623 } 624 // gzip.Reader.Close() does NOT close the underlying 625 // reader, so we need to wrap it and ensure close will 626 // close the response body. Otherwise we'll leak it. 627 asb.rr = &wrapReader{Reader: bufStream0, wrapped: resp.Body} 628 } else { 629 asb.rr = &wrapReader{Reader: bufStream, wrapped: resp.Body} 630 } 631 return nil 632 } 633 634 // GetStream returns a stream of bytes consisting of an Arrow IPC Record 635 // batch stream. Close should be called on the returned stream when done 636 // to ensure no leaked memory. 637 func (asb *ArrowStreamBatch) GetStream(ctx context.Context) (io.ReadCloser, error) { 638 if asb.rr == nil { 639 if err := asb.downloadChunkStreamHelper(ctx); err != nil { 640 return nil, err 641 } 642 } 643 644 return asb.rr, nil 645 } 646 647 // ArrowStreamLoader is a convenience interface for downloading 648 // Snowflake results via multiple Arrow Record Batch streams. 649 // 650 // Some queries from Snowflake do not return Arrow data regardless 651 // of the settings, such as "SHOW WAREHOUSES". In these cases, 652 // you'll find TotalRows() > 0 but GetBatches returns no batches 653 // and no errors. In this case, the data is accessible via JSONData 654 // with the actual types matching up to the metadata in RowTypes. 655 type ArrowStreamLoader interface { 656 GetBatches() ([]ArrowStreamBatch, error) 657 TotalRows() int64 658 RowTypes() []execResponseRowType 659 Location() *time.Location 660 JSONData() [][]*string 661 } 662 663 type snowflakeArrowStreamChunkDownloader struct { 664 sc *snowflakeConn 665 ChunkMetas []execResponseChunk 666 Total int64 667 Qrmk string 668 ChunkHeader map[string]string 669 FuncGet func(context.Context, *snowflakeConn, string, map[string]string, time.Duration) (*http.Response, error) 670 RowSet rowSetType 671 } 672 673 func (scd *snowflakeArrowStreamChunkDownloader) Location() *time.Location { 674 if scd.sc != nil { 675 return getCurrentLocation(scd.sc.cfg.Params) 676 } 677 return nil 678 } 679 func (scd *snowflakeArrowStreamChunkDownloader) TotalRows() int64 { return scd.Total } 680 func (scd *snowflakeArrowStreamChunkDownloader) RowTypes() []execResponseRowType { 681 return scd.RowSet.RowType 682 } 683 func (scd *snowflakeArrowStreamChunkDownloader) JSONData() [][]*string { 684 return scd.RowSet.JSON 685 } 686 687 // the server might have had an empty first batch, check if we can decode 688 // that first batch, if not we skip it. 689 func (scd *snowflakeArrowStreamChunkDownloader) maybeFirstBatch() []byte { 690 if scd.RowSet.RowSetBase64 == "" { 691 return nil 692 } 693 694 // first batch 695 rowSetBytes, err := base64.StdEncoding.DecodeString(scd.RowSet.RowSetBase64) 696 if err != nil { 697 // match logic in buildFirstArrowChunk 698 // assume there's no first chunk if we can't decode the base64 string 699 return nil 700 } 701 702 // verify it's a valid ipc stream, otherwise skip it 703 rr, err := ipc.NewReader(bytes.NewReader(rowSetBytes)) 704 if err != nil { 705 return nil 706 } 707 rr.Release() 708 709 return rowSetBytes 710 } 711 712 func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamBatch, err error) { 713 chunkMetaLen := len(scd.ChunkMetas) 714 loc := scd.Location() 715 716 out = make([]ArrowStreamBatch, chunkMetaLen, chunkMetaLen+1) 717 toFill := out 718 rowSetBytes := scd.maybeFirstBatch() 719 // if there was no first batch in the response from the server, 720 // skip it and move on. toFill == out 721 // otherwise expand out by one to account for the first batch 722 // and fill it in. have toFill refer to the slice of out excluding 723 // the first batch. 724 if len(rowSetBytes) > 0 { 725 out = out[:chunkMetaLen+1] 726 out[0] = ArrowStreamBatch{ 727 scd: scd, 728 Loc: loc, 729 rr: io.NopCloser(bytes.NewReader(rowSetBytes)), 730 } 731 toFill = out[1:] 732 } 733 734 var totalCounted int64 735 for i := range toFill { 736 toFill[i] = ArrowStreamBatch{ 737 idx: i, 738 numrows: int64(scd.ChunkMetas[i].RowCount), 739 Loc: loc, 740 scd: scd, 741 } 742 totalCounted += int64(scd.ChunkMetas[i].RowCount) 743 } 744 745 if len(rowSetBytes) > 0 { 746 // if we had a first batch, fill in the numrows 747 out[0].numrows = scd.Total - totalCounted 748 } 749 return 750 } 751 752 func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) { 753 sc := &snowflakeConn{ 754 SequenceCounter: 0, 755 ctx: ctx, 756 cfg: &config, 757 queryContextCache: (&queryContextCache{}).init(), 758 currentTimeProvider: defaultTimeProvider, 759 } 760 err := initEasyLogging(config.ClientConfigFile) 761 if err != nil { 762 return nil, err 763 } 764 var st http.RoundTripper = SnowflakeTransport 765 if sc.cfg.Transporter == nil { 766 if sc.cfg.InsecureMode { 767 // no revocation check with OCSP. Think twice when you want to enable this option. 768 st = snowflakeInsecureTransport 769 } else { 770 // set OCSP fail open mode 771 ocspResponseCacheLock.Lock() 772 atomic.StoreUint32((*uint32)(&ocspFailOpen), uint32(sc.cfg.OCSPFailOpen)) 773 ocspResponseCacheLock.Unlock() 774 } 775 } else { 776 // use the custom transport 777 st = sc.cfg.Transporter 778 } 779 if strings.HasSuffix(sc.cfg.Host, privateLinkSuffix) { 780 if err := sc.setupOCSPPrivatelink(sc.cfg.Application, sc.cfg.Host); err != nil { 781 return nil, err 782 } 783 } else { 784 if _, set := os.LookupEnv(cacheServerURLEnv); set { 785 os.Unsetenv(cacheServerURLEnv) 786 } 787 } 788 var tokenAccessor TokenAccessor 789 if sc.cfg.TokenAccessor != nil { 790 tokenAccessor = sc.cfg.TokenAccessor 791 } else { 792 tokenAccessor = getSimpleTokenAccessor() 793 } 794 795 // authenticate 796 sc.rest = &snowflakeRestful{ 797 Host: sc.cfg.Host, 798 Port: sc.cfg.Port, 799 Protocol: sc.cfg.Protocol, 800 Client: &http.Client{ 801 // request timeout including reading response body 802 Timeout: sc.cfg.ClientTimeout, 803 Transport: st, 804 }, 805 JWTClient: &http.Client{ 806 Timeout: sc.cfg.JWTClientTimeout, 807 Transport: st, 808 }, 809 TokenAccessor: tokenAccessor, 810 LoginTimeout: sc.cfg.LoginTimeout, 811 RequestTimeout: sc.cfg.RequestTimeout, 812 MaxRetryCount: sc.cfg.MaxRetryCount, 813 FuncPost: postRestful, 814 FuncGet: getRestful, 815 FuncAuthPost: postAuthRestful, 816 FuncPostQuery: postRestfulQuery, 817 FuncPostQueryHelper: postRestfulQueryHelper, 818 FuncRenewSession: renewRestfulSession, 819 FuncPostAuth: postAuth, 820 FuncCloseSession: closeSession, 821 FuncCancelQuery: cancelQuery, 822 FuncPostAuthSAML: postAuthSAML, 823 FuncPostAuthOKTA: postAuthOKTA, 824 FuncGetSSO: getSSO, 825 } 826 827 if sc.cfg.DisableTelemetry { 828 sc.telemetry = &snowflakeTelemetry{enabled: false} 829 } else { 830 sc.telemetry = &snowflakeTelemetry{ 831 flushSize: defaultFlushSize, 832 sr: sc.rest, 833 mutex: &sync.Mutex{}, 834 enabled: true, 835 } 836 } 837 838 return sc, nil 839 }