github.com/snowflakedb/gosnowflake@v1.9.0/arrow_chunk.go (about)

     1  // Copyright (c) 2020-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"bytes"
     7  	"encoding/base64"
     8  	"time"
     9  
    10  	"github.com/apache/arrow/go/v15/arrow"
    11  	"github.com/apache/arrow/go/v15/arrow/ipc"
    12  	"github.com/apache/arrow/go/v15/arrow/memory"
    13  )
    14  
    15  type arrowResultChunk struct {
    16  	reader    *ipc.Reader
    17  	rowCount  int
    18  	loc       *time.Location
    19  	allocator memory.Allocator
    20  }
    21  
    22  func (arc *arrowResultChunk) decodeArrowChunk(rowType []execResponseRowType, highPrec bool) ([]chunkRowType, error) {
    23  	logger.Debug("Arrow Decoder")
    24  	var chunkRows []chunkRowType
    25  
    26  	for arc.reader.Next() {
    27  		record := arc.reader.Record()
    28  
    29  		start := len(chunkRows)
    30  		numRows := int(record.NumRows())
    31  		columns := record.Columns()
    32  		chunkRows = append(chunkRows, make([]chunkRowType, numRows)...)
    33  		for i := start; i < start+numRows; i++ {
    34  			chunkRows[i].ArrowRow = make([]snowflakeValue, len(columns))
    35  		}
    36  
    37  		for colIdx, col := range columns {
    38  			values := make([]snowflakeValue, numRows)
    39  			if err := arrowToValue(values, rowType[colIdx], col, arc.loc, highPrec); err != nil {
    40  				return nil, err
    41  			}
    42  
    43  			for i := range values {
    44  				chunkRows[start+i].ArrowRow[colIdx] = values[i]
    45  			}
    46  		}
    47  		arc.rowCount += numRows
    48  	}
    49  
    50  	return chunkRows, arc.reader.Err()
    51  }
    52  
    53  func (arc *arrowResultChunk) decodeArrowBatch(scd *snowflakeChunkDownloader) (*[]arrow.Record, error) {
    54  	var records []arrow.Record
    55  	defer arc.reader.Release()
    56  
    57  	for arc.reader.Next() {
    58  		rawRecord := arc.reader.Record()
    59  
    60  		record, err := arrowToRecord(scd.ctx, rawRecord, arc.allocator, scd.RowSet.RowType, arc.loc)
    61  		if err != nil {
    62  			return nil, err
    63  		}
    64  		records = append(records, record)
    65  	}
    66  
    67  	return &records, arc.reader.Err()
    68  }
    69  
    70  // Build arrow chunk based on RowSet of base64
    71  func buildFirstArrowChunk(rowsetBase64 string, loc *time.Location, alloc memory.Allocator) (arrowResultChunk, error) {
    72  	rowSetBytes, err := base64.StdEncoding.DecodeString(rowsetBase64)
    73  	if err != nil {
    74  		return arrowResultChunk{}, err
    75  	}
    76  	rr, err := ipc.NewReader(bytes.NewReader(rowSetBytes), ipc.WithAllocator(alloc))
    77  	if err != nil {
    78  		return arrowResultChunk{}, err
    79  	}
    80  
    81  	return arrowResultChunk{rr, 0, loc, alloc}, nil
    82  }