github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/tests/integration/table_multiple_result_sets_test.go (about)

     1  //go:build integration
     2  // +build integration
     3  
     4  package integration
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"os"
    10  	"path"
    11  	"testing"
    12  
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/ydb-platform/ydb-go-sdk/v3"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/table/options"
    19  	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
    20  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    21  )
    22  
    23  type scopeTableStreamExecuteScanQuery struct {
    24  	folder          string
    25  	tableName       string
    26  	upsertRowsCount int
    27  
    28  	sum uint64
    29  }
    30  
    31  func TestTableMultipleResultSets(t *testing.T) {
    32  	var (
    33  		scope = &scopeTableStreamExecuteScanQuery{
    34  			folder:          t.Name(),
    35  			tableName:       "stream_query_table",
    36  			upsertRowsCount: 100000,
    37  			sum:             0,
    38  		}
    39  		ctx = xtest.Context(t)
    40  	)
    41  
    42  	db, err := ydb.Open(ctx,
    43  		"", // corner case for check replacement of endpoint+database+secure
    44  		ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")),
    45  		ydb.WithLogger(
    46  			newLogger(t),
    47  			trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`),
    48  		),
    49  	)
    50  	require.NoError(t, err)
    51  
    52  	defer func() {
    53  		err = db.Close(ctx)
    54  		require.NoError(t, err)
    55  	}()
    56  
    57  	t.Run("create", func(t *testing.T) {
    58  		t.Run("table", func(t *testing.T) {
    59  			err = db.Table().Do(ctx,
    60  				func(ctx context.Context, s table.Session) (err error) {
    61  					_ = s.ExecuteSchemeQuery(
    62  						ctx, `
    63  						PRAGMA TablePathPrefix("`+path.Join(db.Name(), scope.folder)+`");
    64  						DROP TABLE `+scope.tableName+`;`,
    65  					)
    66  					return s.ExecuteSchemeQuery(
    67  						ctx, `
    68  						PRAGMA TablePathPrefix("`+path.Join(db.Name(), scope.folder)+`");
    69  						CREATE TABLE `+scope.tableName+` (val Int32, PRIMARY KEY (val));`,
    70  					)
    71  				},
    72  				table.WithIdempotent(),
    73  			)
    74  			require.NoError(t, err)
    75  		})
    76  	})
    77  
    78  	t.Run("upsert", func(t *testing.T) {
    79  		t.Run("data", func(t *testing.T) {
    80  			// - upsert data
    81  			values := make([]types.Value, 0, scope.upsertRowsCount)
    82  			for i := 0; i < scope.upsertRowsCount; i++ {
    83  				scope.sum += uint64(i)
    84  				values = append(
    85  					values,
    86  					types.StructValue(
    87  						types.StructFieldValue("val", types.Int32Value(int32(i))),
    88  					),
    89  				)
    90  			}
    91  			err := db.Table().Do(ctx,
    92  				func(ctx context.Context, s table.Session) (err error) {
    93  					_, _, err = s.Execute(ctx,
    94  						table.TxControl(
    95  							table.BeginTx(
    96  								table.WithSerializableReadWrite(),
    97  							),
    98  							table.CommitTx(),
    99  						), `
   100  							PRAGMA TablePathPrefix("`+path.Join(db.Name(), scope.folder)+`");
   101  							DECLARE $values AS List<Struct<
   102  								val: Int32,
   103  							> >;
   104  							UPSERT INTO `+scope.tableName+`
   105  							SELECT
   106  								val 
   107  							FROM
   108  								AS_TABLE($values);            
   109  						`, table.NewQueryParameters(
   110  							table.ValueParam(
   111  								"$values",
   112  								types.ListValue(values...),
   113  							),
   114  						),
   115  					)
   116  					return err
   117  				},
   118  				table.WithIdempotent(),
   119  			)
   120  			require.NoError(t, err)
   121  		})
   122  	})
   123  
   124  	t.Run("scan", func(t *testing.T) {
   125  		t.Run("scan", func(t *testing.T) {
   126  			err := db.Table().Do(ctx,
   127  				func(ctx context.Context, s table.Session) (err error) {
   128  					res, err := s.StreamExecuteScanQuery(
   129  						ctx, `
   130  							PRAGMA TablePathPrefix("`+path.Join(db.Name(), scope.folder)+`");
   131  							SELECT val FROM `+scope.tableName+`;`,
   132  						table.NewQueryParameters(),
   133  						options.WithExecuteScanQueryStats(options.ExecuteScanQueryStatsTypeFull),
   134  					)
   135  					if err != nil {
   136  						return err
   137  					}
   138  					var (
   139  						resultSetsCount = 0
   140  						rowsCount       = 0
   141  						checkSum        uint64
   142  					)
   143  					for res.NextResultSet(ctx) {
   144  						resultSetsCount++
   145  						for res.NextRow() {
   146  							rowsCount++
   147  							var val *int32
   148  							err = res.Scan(&val)
   149  							if err != nil {
   150  								return err
   151  							}
   152  							checkSum += uint64(*val)
   153  						}
   154  						if stats := res.Stats(); stats != nil {
   155  							t.Logf(" --- query stats: compilation: %v, process CPU time: %v, affected shards: %v\n",
   156  								stats.Compilation(),
   157  								stats.ProcessCPUTime(),
   158  								func() (count uint64) {
   159  									for {
   160  										phase, ok := stats.NextPhase()
   161  										if !ok {
   162  											return
   163  										}
   164  										count += phase.AffectedShards()
   165  									}
   166  								}(),
   167  							)
   168  						}
   169  					}
   170  
   171  					if err = res.Err(); err != nil {
   172  						return err
   173  					}
   174  
   175  					if rowsCount != scope.upsertRowsCount {
   176  						return fmt.Errorf("wrong rows count: %v, exp: %v", rowsCount, scope.upsertRowsCount)
   177  					}
   178  
   179  					if scope.sum != checkSum {
   180  						return fmt.Errorf("wrong checkSum: %v, exp: %v", checkSum, scope.sum)
   181  					}
   182  
   183  					if resultSetsCount <= 1 {
   184  						return fmt.Errorf("wrong result sets count: %v", resultSetsCount)
   185  					}
   186  
   187  					return nil
   188  				},
   189  				table.WithIdempotent(),
   190  			)
   191  			require.NoError(t, err)
   192  		})
   193  	})
   194  }