github.com/snowflakedb/gosnowflake@v1.9.0/cmd/arrow/batches/arrow_batches.go (about)

     1  package main
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"flag"
     8  	"fmt"
     9  	"github.com/apache/arrow/go/v15/arrow"
    10  	"github.com/apache/arrow/go/v15/arrow/array"
    11  	"github.com/apache/arrow/go/v15/arrow/memory"
    12  	"log"
    13  	"sync"
    14  	"time"
    15  
    16  	sf "github.com/snowflakedb/gosnowflake"
    17  )
    18  
    19  type sampleRecord struct {
    20  	batchID  int
    21  	workerID int
    22  	number   int32
    23  	string   string
    24  	ts       *time.Time
    25  }
    26  
    27  func (s sampleRecord) String() string {
    28  	return fmt.Sprintf("batchID: %v, workerID: %v, number: %v, string: %v, ts: %v", s.batchID, s.workerID, s.number, s.string, s.ts)
    29  }
    30  
    31  func main() {
    32  	if !flag.Parsed() {
    33  		flag.Parse()
    34  	}
    35  
    36  	cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{
    37  		{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
    38  		{Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true},
    39  		{Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true},
    40  		{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
    41  		{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
    42  		{Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false},
    43  	})
    44  	if err != nil {
    45  		log.Fatalf("failed to create Config, err: %v", err)
    46  	}
    47  
    48  	dsn, err := sf.DSN(cfg)
    49  	if err != nil {
    50  		log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err)
    51  	}
    52  
    53  	ctx :=
    54  		sf.WithArrowBatchesTimestampOption(
    55  			sf.WithArrowAllocator(
    56  				sf.WithArrowBatches(context.Background()), memory.DefaultAllocator), sf.UseOriginalTimestamp)
    57  
    58  	query := "SELECT SEQ4(), 'example ' || (SEQ4() * 2), " +
    59  		" TO_TIMESTAMP_NTZ('9999-01-01 13:13:13.' || LPAD(SEQ4(),9,'0'))  ltz " +
    60  		" FROM TABLE(GENERATOR(ROWCOUNT=>30000))"
    61  
    62  	db, err := sql.Open("snowflake", dsn)
    63  	if err != nil {
    64  		log.Fatalf("failed to connect. %v, err: %v", dsn, err)
    65  	}
    66  	defer db.Close()
    67  
    68  	conn, _ := db.Conn(ctx)
    69  	defer conn.Close()
    70  
    71  	var rows driver.Rows
    72  	err = conn.Raw(func(x interface{}) error {
    73  		rows, err = x.(driver.QueryerContext).QueryContext(ctx, query, nil)
    74  		return err
    75  	})
    76  	if err != nil {
    77  		log.Fatalf("unable to run the query. err: %v", err)
    78  	}
    79  	defer rows.Close()
    80  
    81  	batches, err := rows.(sf.SnowflakeRows).GetArrowBatches()
    82  	batchIds := make(chan int, 1)
    83  	maxWorkers := len(batches)
    84  	sampleRecordsPerBatch := make([][]sampleRecord, len(batches))
    85  
    86  	var waitGroup sync.WaitGroup
    87  	for workerID := 0; workerID < maxWorkers; workerID++ {
    88  		waitGroup.Add(1)
    89  		go func(waitGroup *sync.WaitGroup, batchIDs chan int, workerId int) {
    90  			defer waitGroup.Done()
    91  
    92  			for batchID := range batchIDs {
    93  				records, err := batches[batchID].Fetch()
    94  				if err != nil {
    95  					log.Fatalf("Error while fetching batch %v: %v", batchID, err)
    96  				}
    97  				sampleRecordsPerBatch[batchID] = make([]sampleRecord, batches[batchID].GetRowCount())
    98  				totalRowID := 0
    99  				convertFromColumnsToRows(records, sampleRecordsPerBatch, batchID, workerId, totalRowID, batches[batchID])
   100  			}
   101  		}(&waitGroup, batchIds, workerID)
   102  	}
   103  
   104  	for batchID := 0; batchID < len(batches); batchID++ {
   105  		batchIds <- batchID
   106  	}
   107  	close(batchIds)
   108  	waitGroup.Wait()
   109  
   110  	for _, batchSampleRecords := range sampleRecordsPerBatch {
   111  		for _, sampleRecord := range batchSampleRecords {
   112  			fmt.Println(sampleRecord)
   113  		}
   114  	}
   115  	for batchID, batch := range batches {
   116  		fmt.Printf("BatchId: %v, number of records: %v\n", batchID, batch.GetRowCount())
   117  	}
   118  }
   119  
   120  func convertFromColumnsToRows(records *[]arrow.Record, sampleRecordsPerBatch [][]sampleRecord, batchID int,
   121  	workerID int, totalRowID int, batch *sf.ArrowBatch) {
   122  	for _, record := range *records {
   123  		for rowID, intColumn := range record.Column(0).(*array.Int32).Int32Values() {
   124  			sampleRecord := sampleRecord{
   125  				batchID:  batchID,
   126  				workerID: workerID,
   127  				number:   intColumn,
   128  				string:   record.Column(1).(*array.String).Value(rowID),
   129  				ts:       batch.ArrowSnowflakeTimestampToTime(record, 2, rowID),
   130  			}
   131  			sampleRecordsPerBatch[batchID][totalRowID] = sampleRecord
   132  			totalRowID++
   133  		}
   134  	}
   135  }