github.com/snowflakedb/gosnowflake@v1.9.0/async_test.go (about)

     1  // Copyright (c) 2021-2023 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"fmt"
     9  	"testing"
    10  )
    11  
    12  func TestAsyncMode(t *testing.T) {
    13  	ctx := WithAsyncMode(context.Background())
    14  	numrows := 100000
    15  	cnt := 0
    16  	var idx int
    17  	var v string
    18  
    19  	runDBTest(t, func(dbt *DBTest) {
    20  		rows := dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows))
    21  		defer rows.Close()
    22  
    23  		// Next() will block and wait until results are available
    24  		for rows.Next() {
    25  			if err := rows.Scan(&idx, &v); err != nil {
    26  				t.Fatal(err)
    27  			}
    28  			cnt++
    29  		}
    30  		logger.Infof("NextResultSet: %v", rows.NextResultSet())
    31  
    32  		if cnt != numrows {
    33  			t.Errorf("number of rows didn't match. expected: %v, got: %v", numrows, cnt)
    34  		}
    35  
    36  		dbt.mustExec("create or replace table test_async_exec (value boolean)")
    37  		res := dbt.mustExecContext(ctx, "insert into test_async_exec values (true)")
    38  		count, err := res.RowsAffected()
    39  		if err != nil {
    40  			t.Fatalf("res.RowsAffected() returned error: %v", err)
    41  		}
    42  		if count != 1 {
    43  			t.Fatalf("expected 1 affected row, got %d", count)
    44  		}
    45  	})
    46  }
    47  
    48  func TestAsyncModePing(t *testing.T) {
    49  	ctx := WithAsyncMode(context.Background())
    50  
    51  	runDBTest(t, func(dbt *DBTest) {
    52  		defer func() {
    53  			if r := recover(); r != nil {
    54  				t.Fatalf("panic during ping: %v", r)
    55  			}
    56  		}()
    57  		err := dbt.conn.PingContext(ctx)
    58  		if err != nil {
    59  			t.Fatal(err)
    60  		}
    61  	})
    62  }
    63  
    64  func TestAsyncModeMultiStatement(t *testing.T) {
    65  	withMultiStmtCtx, _ := WithMultiStatement(context.Background(), 6)
    66  	ctx := WithAsyncMode(withMultiStmtCtx)
    67  	multiStmtQuery := "begin;\n" +
    68  		"delete from test_multi_statement_async;\n" +
    69  		"insert into test_multi_statement_async values (1, 'a'), (2, 'b');\n" +
    70  		"select 1;\n" +
    71  		"select 2;\n" +
    72  		"rollback;"
    73  
    74  	runDBTest(t, func(dbt *DBTest) {
    75  		dbt.mustExec("drop table if exists test_multi_statement_async")
    76  		dbt.mustExec(`create or replace table test_multi_statement_async(
    77  			c1 number, c2 string) as select 10, 'z'`)
    78  		defer dbt.mustExec("drop table if exists test_multi_statement_async")
    79  
    80  		res := dbt.mustExecContext(ctx, multiStmtQuery)
    81  		count, err := res.RowsAffected()
    82  		if err != nil {
    83  			t.Fatalf("res.RowsAffected() returned error: %v", err)
    84  		}
    85  		if count != 3 {
    86  			t.Fatalf("expected 3 affected rows, got %d", count)
    87  		}
    88  	})
    89  }
    90  
    91  func TestAsyncModeCancel(t *testing.T) {
    92  	withCancelCtx, cancel := context.WithCancel(context.Background())
    93  	ctx := WithAsyncMode(withCancelCtx)
    94  	numrows := 100000
    95  
    96  	runDBTest(t, func(dbt *DBTest) {
    97  		dbt.mustQueryContext(ctx, fmt.Sprintf(selectRandomGenerator, numrows))
    98  		cancel()
    99  	})
   100  }
   101  
   102  func TestAsyncQueryFail(t *testing.T) {
   103  	ctx := WithAsyncMode(context.Background())
   104  	runDBTest(t, func(dbt *DBTest) {
   105  		rows := dbt.mustQueryContext(ctx, "selectt 1")
   106  		defer rows.Close()
   107  
   108  		if rows.Next() {
   109  			t.Fatal("should have no rows available")
   110  		} else {
   111  			if err := rows.Err(); err == nil {
   112  				t.Fatal("should return a syntax error")
   113  			}
   114  		}
   115  	})
   116  }
   117  
   118  func TestMultipleAsyncQueries(t *testing.T) {
   119  	ctx := WithAsyncMode(context.Background())
   120  	s1 := "foo"
   121  	s2 := "bar"
   122  	ch1 := make(chan string)
   123  	ch2 := make(chan string)
   124  
   125  	db := openDB(t)
   126  
   127  	runDBTest(t, func(dbt *DBTest) {
   128  		rows1, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s1, 30))
   129  		if err != nil {
   130  			t.Fatalf("can't read rows1: %v", err)
   131  		}
   132  		defer rows1.Close()
   133  		rows2, err := db.QueryContext(ctx, fmt.Sprintf("select distinct '%v' from table (generator(timelimit=>%v))", s2, 10))
   134  		if err != nil {
   135  			t.Fatalf("can't read rows2: %v", err)
   136  		}
   137  		defer rows2.Close()
   138  
   139  		go retrieveRows(rows1, ch1)
   140  		go retrieveRows(rows2, ch2)
   141  		select {
   142  		case res := <-ch1:
   143  			t.Fatalf("value %v should not have been called earlier.", res)
   144  		case res := <-ch2:
   145  			if res != s2 {
   146  				t.Fatalf("query failed. expected: %v, got: %v", s2, res)
   147  			}
   148  		}
   149  	})
   150  }
   151  
   152  func retrieveRows(rows *sql.Rows, ch chan string) {
   153  	var s string
   154  	for rows.Next() {
   155  		if err := rows.Scan(&s); err != nil {
   156  			ch <- err.Error()
   157  			close(ch)
   158  			return
   159  		}
   160  	}
   161  	ch <- s
   162  	close(ch)
   163  }
   164  
   165  func TestLongRunningAsyncQuery(t *testing.T) {
   166  	conn := openConn(t)
   167  	defer conn.Close()
   168  
   169  	ctx, _ := WithMultiStatement(context.Background(), 0)
   170  	query := "CALL SYSTEM$WAIT(50, 'SECONDS');use snowflake_sample_data"
   171  
   172  	rows, err := conn.QueryContext(WithAsyncMode(ctx), query)
   173  	if err != nil {
   174  		t.Fatalf("failed to run a query. %v, err: %v", query, err)
   175  	}
   176  	defer rows.Close()
   177  	var v string
   178  	i := 0
   179  	for {
   180  		for rows.Next() {
   181  			err := rows.Scan(&v)
   182  			if err != nil {
   183  				t.Fatalf("failed to get result. err: %v", err)
   184  			}
   185  			if v == "" {
   186  				t.Fatal("should have returned a result")
   187  			}
   188  			results := []string{"waited 50 seconds", "Statement executed successfully."}
   189  			if v != results[i] {
   190  				t.Fatalf("unexpected result returned. expected: %v, but got: %v", results[i], v)
   191  			}
   192  			i++
   193  		}
   194  		if !rows.NextResultSet() {
   195  			break
   196  		}
   197  	}
   198  }
   199  
   200  func TestLongRunningAsyncQueryFetchResultByID(t *testing.T) {
   201  	runDBTest(t, func(dbt *DBTest) {
   202  		queryIDChan := make(chan string, 1)
   203  		ctx := WithAsyncMode(context.Background())
   204  		ctx = WithQueryIDChan(ctx, queryIDChan)
   205  
   206  		// Run a long running query asynchronously
   207  		go dbt.mustExecContext(ctx, "CALL SYSTEM$WAIT(50, 'SECONDS')")
   208  
   209  		// Get the query ID without waiting for the query to finish
   210  		queryID := <-queryIDChan
   211  		assertNotNilF(t, queryID, "expected a nonempty query ID")
   212  
   213  		ctx = WithFetchResultByID(ctx, queryID)
   214  		rows := dbt.mustQueryContext(ctx, "")
   215  		defer rows.Close()
   216  
   217  		var v string
   218  		assertTrueF(t, rows.Next())
   219  		err := rows.Scan(&v)
   220  		assertNilF(t, err, fmt.Sprintf("failed to get result. err: %v", err))
   221  		assertNotNilF(t, v, "should have returned a result")
   222  
   223  		expected := "waited 50 seconds"
   224  		if v != expected {
   225  			t.Fatalf("unexpected result returned. expected: %v, but got: %v", expected, v)
   226  		}
   227  		assertFalseF(t, rows.NextResultSet())
   228  	})
   229  }