github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/tests/integration/query_execute_script_test.go (about)

     1  //go:build integration
     2  // +build integration
     3  
     4  package integration
     5  
     6  import (
     7  	"errors"
     8  	"io"
     9  	"os"
    10  	"path"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/stretchr/testify/require"
    15  
    16  	"github.com/ydb-platform/ydb-go-sdk/v3"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/version"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    19  	"github.com/ydb-platform/ydb-go-sdk/v3/query"
    20  	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
    21  )
    22  
    23  func TestQueryExecuteScript(sourceTest *testing.T) {
    24  	if version.Lt(os.Getenv("YDB_VERSION"), "24.1") {
    25  		sourceTest.Skip("query service not allowed in YDB version '" + os.Getenv("YDB_VERSION") + "'")
    26  	}
    27  
    28  	t := xtest.MakeSyncedTest(sourceTest)
    29  	var (
    30  		folder           = t.Name()
    31  		tableName        = `test`
    32  		db               *ydb.Driver
    33  		err              error
    34  		upsertRowsCount  = 100000
    35  		batchSize        = 10000
    36  		expectedCheckSum = uint64(4999950000)
    37  		ctx              = xtest.Context(t)
    38  	)
    39  
    40  	db, err = ydb.Open(ctx,
    41  		os.Getenv("YDB_CONNECTION_STRING"),
    42  		ydb.WithAccessTokenCredentials(
    43  			os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"),
    44  		),
    45  	)
    46  	if err != nil {
    47  		t.Fatal(err)
    48  	}
    49  	defer func(db *ydb.Driver) {
    50  		// cleanup
    51  		_ = db.Close(ctx)
    52  	}(db)
    53  
    54  	err = db.Query().Exec(ctx,
    55  		"CREATE TABLE IF NOT EXISTS `"+path.Join(db.Name(), folder, tableName)+"` (val Int64, PRIMARY KEY (val))",
    56  	)
    57  	require.NoError(t, err)
    58  
    59  	require.Zero(t, upsertRowsCount%batchSize, "wrong batch size: (%d mod %d = %d) != 0", upsertRowsCount, batchSize, upsertRowsCount%batchSize)
    60  
    61  	var upserted uint32
    62  	for i := 0; i < (upsertRowsCount / batchSize); i++ {
    63  		var (
    64  			from = int32(i * batchSize)
    65  			to   = int32((i + 1) * batchSize)
    66  		)
    67  		t.Logf("upserting rows %d..%d\n", from, to-1)
    68  		values := make([]types.Value, 0, batchSize)
    69  		for j := from; j < to; j++ {
    70  			values = append(
    71  				values,
    72  				types.StructValue(
    73  					types.StructFieldValue("val", types.Int32Value(j)),
    74  				),
    75  			)
    76  		}
    77  		err := db.Query().Exec(ctx, `
    78  						DECLARE $values AS List<Struct<
    79  							val: Int32,
    80  						>>;
    81  						UPSERT INTO `+"`"+path.Join(db.Name(), folder, tableName)+"`"+`
    82  						SELECT
    83  							val 
    84  						FROM
    85  							AS_TABLE($values);            
    86  					`, query.WithParameters(
    87  			ydb.ParamsBuilder().Param("$values").BeginList().AddItems(values...).EndList().Build(),
    88  		),
    89  		)
    90  		require.NoError(t, err)
    91  		upserted += uint32(to - from)
    92  	}
    93  	require.Equal(t, uint32(upsertRowsCount), upserted)
    94  
    95  	row, err := db.Query().QueryRow(ctx,
    96  		"SELECT CAST(COUNT(*) AS Uint64) FROM `"+path.Join(db.Name(), folder, tableName)+"`;",
    97  	)
    98  	require.NoError(t, err)
    99  	var rowsFromDb uint64
   100  	err = row.Scan(&rowsFromDb)
   101  	require.NoError(t, err)
   102  	require.Equal(t, uint64(upsertRowsCount), rowsFromDb)
   103  
   104  	row, err = db.Query().QueryRow(ctx,
   105  		"SELECT CAST(SUM(val) AS Uint64) FROM `"+path.Join(db.Name(), folder, tableName)+"`;",
   106  	)
   107  	require.NoError(t, err)
   108  	var checkSumFromDb uint64
   109  	err = row.Scan(&checkSumFromDb)
   110  	require.NoError(t, err)
   111  	require.Equal(t, expectedCheckSum, checkSumFromDb)
   112  
   113  	op, err := db.Query().ExecuteScript(ctx,
   114  		"SELECT val FROM `"+path.Join(db.Name(), folder, tableName)+"`;",
   115  		time.Hour,
   116  	)
   117  	require.NoError(t, err)
   118  
   119  	for {
   120  		status, err := db.Operation().Get(ctx, op.ID)
   121  		require.NoError(t, err)
   122  		if status.Ready {
   123  			break
   124  		}
   125  		time.Sleep(time.Second)
   126  	}
   127  
   128  	var (
   129  		nextToken string
   130  		rowsCount = 0
   131  		checkSum  = uint64(0)
   132  	)
   133  	for {
   134  		result, err := db.Query().FetchScriptResults(ctx, op.ID,
   135  			query.WithResultSetIndex(0),
   136  			query.WithRowsLimit(1000),
   137  			query.WithFetchToken(nextToken),
   138  		)
   139  		require.NoError(t, err)
   140  		nextToken = result.NextToken
   141  		require.EqualValues(t, 0, result.ResultSetIndex)
   142  		t.Logf("reading next 1000 rows. Current rows count: %v\n", rowsCount)
   143  		for {
   144  			row, err := result.ResultSet.NextRow(ctx)
   145  			if err != nil {
   146  				if errors.Is(err, io.EOF) {
   147  					break
   148  				}
   149  				t.Fatal(err)
   150  			}
   151  			rowsCount++
   152  			var val int64
   153  			err = row.Scan(&val)
   154  			checkSum += uint64(val)
   155  		}
   156  		if result.NextToken == "" {
   157  			break
   158  		}
   159  	}
   160  	require.EqualValues(t, upsertRowsCount, rowsCount)
   161  	require.EqualValues(t, expectedCheckSum, checkSum)
   162  }