github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/tests/integration/table_cross_join_test.go (about)

     1  //go:build integration
     2  // +build integration
     3  
     4  package integration
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/table/options"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
    17  )
    18  
    19  func TestTableCrossJoin(t *testing.T) {
    20  	var (
    21  		ctx        = xtest.Context(t)
    22  		scope      = newScope(t)
    23  		db         = scope.Driver()
    24  		table1Path = scope.TablePath(
    25  			withTableName("table1"),
    26  			withCreateTableOptions(
    27  				options.WithColumn("p1", types.Optional(types.TypeText)),
    28  				options.WithPrimaryKeyColumn("p1"),
    29  			),
    30  		)
    31  		_ = scope.TablePath(
    32  			withTableName("table2"),
    33  			withCreateTableOptions(
    34  				options.WithColumn("p1", types.Optional(types.TypeText)),
    35  				options.WithPrimaryKeyColumn("p1"),
    36  			),
    37  		)
    38  	)
    39  	// upsert data into table1
    40  	err := db.Table().Do(ctx,
    41  		func(ctx context.Context, s table.Session) (err error) {
    42  			return s.BulkUpsert(ctx, table1Path, types.ListValue(types.StructValue(
    43  				types.StructFieldValue("p1", types.TextValue("foo")),
    44  			)))
    45  		},
    46  		table.WithIdempotent(),
    47  	)
    48  	scope.Require.NoError(err)
    49  
    50  	for _, tt := range []struct {
    51  		name       string
    52  		subQuery   string
    53  		withParams bool
    54  	}{
    55  		{
    56  			name: "Data1FromTable1Data2FromEmptyListWithoutParams",
    57  			subQuery: `
    58  				$data1 = (SELECT * FROM table1);
    59  				$data2 = Cast(AsList() As List<Struct<p2: Utf8>>);
    60  			`,
    61  			withParams: false,
    62  		},
    63  		{
    64  			name: "Data1FromTable1Data2FromLiteralWithoutParams",
    65  			subQuery: `
    66  				$data1 = (SELECT * FROM table1);
    67  				$data2 = Cast(AsList(AsStruct(CAST("t1" AS Utf8) AS p2)) As List<Struct<p2: Utf8>>);
    68  			`,
    69  			withParams: false,
    70  		},
    71  		// failed test-case
    72  		//{
    73  		//	name: "Data1FromTable1DeclareData2WithParams",
    74  		//	subQuery: `
    75  		//		DECLARE $data2 AS List<Struct<p2: Utf8>>;
    76  		//		$data1 = (SELECT * FROM table1);
    77  		//	`,
    78  		//	withParams: true,
    79  		//},
    80  		{
    81  			name: "Data1FromLiteralDeclareData2WithParams",
    82  			subQuery: `
    83  				DECLARE $data2 AS List<Struct<p2: Utf8>>;
    84  				$data1 = (SELECT * FROM AS_TABLE(AsList(AsStruct(CAST("foo" AS Utf8?) AS p1))));
    85  			`,
    86  			withParams: true,
    87  		},
    88  	} {
    89  		t.Run(tt.name, func(t *testing.T) {
    90  			query := `--!syntax_v1
    91  				PRAGMA TablePathPrefix("` + scope.Folder() + `");
    92  				
    93  				/* sub-query */` + tt.subQuery + `
    94  				/* query */
    95  				UPSERT INTO table2
    96  				SELECT d1.p1 AS p1,
    97  				FROM $data1 AS d1
    98  				CROSS JOIN AS_TABLE($data2) AS d2;
    99  		
   100  				SELECT COUNT(*) FROM $data1;
   101  			`
   102  
   103  			params := table.NewQueryParameters()
   104  			if tt.withParams {
   105  				params = table.NewQueryParameters(
   106  					table.ValueParam("$data2", types.ZeroValue(types.List(types.Struct(types.StructField("p2", types.TypeUTF8))))),
   107  				)
   108  			}
   109  
   110  			var got uint64
   111  			err = db.Table().Do(ctx, func(c context.Context, s table.Session) (err error) {
   112  				_, res, err := s.Execute(c, table.DefaultTxControl(), query, params)
   113  				if err != nil {
   114  					return err
   115  				}
   116  				defer res.Close()
   117  				if !res.NextResultSet(ctx) {
   118  					return fmt.Errorf("no result set")
   119  				}
   120  				if !res.NextRow() {
   121  					return fmt.Errorf("no rows")
   122  				}
   123  				if err = res.ScanWithDefaults(&got); err != nil {
   124  					return err
   125  				}
   126  				return res.Err()
   127  			}, table.WithIdempotent())
   128  			require.NoError(t, err, query)
   129  			require.Equal(t, uint64(1), got, query)
   130  		})
   131  	}
   132  }