github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/sugar/params_test.go (about)

     1  package sugar
     2  
     3  import (
     4  	"database/sql"
     5  	"sort"
     6  	"strings"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/bind"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/testutil"
    17  )
    18  
    19  func TestGenerateDeclareSection(t *testing.T) {
    20  	splitDeclares := func(declaresSection string) (declares []string) {
    21  		for _, s := range strings.Split(declaresSection, ";") {
    22  			s = strings.TrimSpace(s)
    23  			if s != "" {
    24  				declares = append(declares, s)
    25  			}
    26  		}
    27  		sort.Strings(declares)
    28  
    29  		return declares
    30  	}
    31  	for _, tt := range []struct {
    32  		params  *table.QueryParameters
    33  		declare string
    34  	}{
    35  		{
    36  			params: table.NewQueryParameters(
    37  				table.ValueParam(
    38  					"$values",
    39  					types.ListValue(
    40  						types.Uint64Value(1),
    41  						types.Uint64Value(2),
    42  						types.Uint64Value(3),
    43  						types.Uint64Value(4),
    44  						types.Uint64Value(5),
    45  					),
    46  				),
    47  			),
    48  			declare: `
    49  				DECLARE $values AS List<Uint64>;
    50  			`,
    51  		},
    52  		{
    53  			params: table.NewQueryParameters(
    54  				table.ValueParam(
    55  					"$delta",
    56  					types.IntervalValueFromDuration(time.Hour),
    57  				),
    58  			),
    59  			declare: `
    60  				DECLARE $delta AS Interval;
    61  			`,
    62  		},
    63  		{
    64  			params: table.NewQueryParameters(
    65  				table.ValueParam("$ts", types.TimestampValueFromTime(time.Now())),
    66  			),
    67  			declare: `
    68  				DECLARE $ts AS Timestamp;
    69  			`,
    70  		},
    71  		{
    72  			params: table.NewQueryParameters(
    73  				table.ValueParam("$a", types.BoolValue(true)),
    74  				table.ValueParam("$b", types.Int64Value(123)),
    75  				table.ValueParam("$c", types.OptionalValue(types.TextValue("test"))),
    76  			),
    77  			declare: `
    78  				DECLARE $a AS Bool;
    79  				DECLARE $b AS Int64; 
    80  				DECLARE $c AS Optional<Utf8>;
    81  			`,
    82  		},
    83  		{
    84  			params: table.NewQueryParameters(
    85  				table.ValueParam("$a", types.BoolValue(true)),
    86  				table.ValueParam("b", types.Int64Value(123)),
    87  				table.ValueParam("c", types.OptionalValue(types.TextValue("test"))),
    88  			),
    89  			declare: `
    90  				DECLARE $a AS Bool;
    91  				DECLARE $b AS Int64; 
    92  				DECLARE $c AS Optional<Utf8>;
    93  			`,
    94  		},
    95  	} {
    96  		t.Run("", func(t *testing.T) {
    97  			declares, err := GenerateDeclareSection(tt.params)
    98  			require.NoError(t, err)
    99  			got := splitDeclares(declares)
   100  			want := splitDeclares(tt.declare)
   101  			if len(got) != len(want) {
   102  				t.Errorf("len(got) = %v, len(want) = %v", len(got), len(want))
   103  			} else {
   104  				for i := range got {
   105  					if strings.TrimSpace(got[i]) != strings.TrimSpace(want[i]) {
   106  						t.Errorf(
   107  							"unexpected generation of declare section:\n%v\n\nwant:\n%v",
   108  							strings.Join(got, ";\n"),
   109  							strings.Join(want, ";\n"),
   110  						)
   111  					}
   112  				}
   113  			}
   114  		})
   115  	}
   116  }
   117  
   118  func TestGenerateDeclareSection_ParameterOption(t *testing.T) {
   119  	b := testutil.QueryBind(bind.AutoDeclare{})
   120  	getDeclares := func(declaresSection string) (declares []string) {
   121  		for _, s := range strings.Split(declaresSection, "\n") {
   122  			s = strings.TrimSpace(s)
   123  			if s != "" && !strings.HasPrefix(s, "--") {
   124  				declares = append(declares, strings.TrimRight(s, ";"))
   125  			}
   126  		}
   127  		sort.Strings(declares)
   128  
   129  		return declares
   130  	}
   131  	for _, tt := range []struct {
   132  		params   []interface{}
   133  		declares []string
   134  	}{
   135  		{
   136  			params: []interface{}{
   137  				table.ValueParam(
   138  					"$values",
   139  					types.ListValue(
   140  						types.Uint64Value(1),
   141  						types.Uint64Value(2),
   142  						types.Uint64Value(3),
   143  						types.Uint64Value(4),
   144  						types.Uint64Value(5),
   145  					),
   146  				),
   147  			},
   148  			declares: []string{
   149  				"DECLARE $values AS List<Uint64>",
   150  			},
   151  		},
   152  		{
   153  			params: []interface{}{
   154  				table.ValueParam(
   155  					"$delta",
   156  					types.IntervalValueFromDuration(time.Hour),
   157  				),
   158  			},
   159  			declares: []string{
   160  				"DECLARE $delta AS Interval",
   161  			},
   162  		},
   163  		{
   164  			params: []interface{}{
   165  				table.ValueParam(
   166  					"$ts",
   167  					types.TimestampValueFromTime(time.Now()),
   168  				),
   169  			},
   170  			declares: []string{
   171  				"DECLARE $ts AS Timestamp",
   172  			},
   173  		},
   174  		{
   175  			params: []interface{}{
   176  				table.ValueParam(
   177  					"$a",
   178  					types.BoolValue(true),
   179  				),
   180  				table.ValueParam(
   181  					"$b",
   182  					types.Int64Value(123),
   183  				),
   184  				table.ValueParam(
   185  					"$c",
   186  					types.OptionalValue(types.TextValue("test")),
   187  				),
   188  			},
   189  			declares: []string{
   190  				"DECLARE $a AS Bool",
   191  				"DECLARE $b AS Int64",
   192  				"DECLARE $c AS Optional<Utf8>",
   193  			},
   194  		},
   195  		{
   196  			params: []interface{}{
   197  				table.ValueParam(
   198  					"$a",
   199  					types.BoolValue(true),
   200  				),
   201  				table.ValueParam(
   202  					"b",
   203  					types.Int64Value(123),
   204  				),
   205  				table.ValueParam(
   206  					"c",
   207  					types.OptionalValue(types.TextValue("test")),
   208  				),
   209  			},
   210  			declares: []string{
   211  				"DECLARE $a AS Bool",
   212  				"DECLARE $b AS Int64",
   213  				"DECLARE $c AS Optional<Utf8>",
   214  			},
   215  		},
   216  	} {
   217  		t.Run("", func(t *testing.T) {
   218  			yql, _, err := b.RewriteQuery("", tt.params...)
   219  			require.NoError(t, err)
   220  			require.Equal(t, tt.declares, getDeclares(yql))
   221  		})
   222  	}
   223  }
   224  
   225  func TestGenerateDeclareSection_NamedArg(t *testing.T) {
   226  	b := testutil.QueryBind(bind.AutoDeclare{})
   227  	getDeclares := func(declaresSection string) (declares []string) {
   228  		for _, s := range strings.Split(declaresSection, "\n") {
   229  			s = strings.TrimSpace(s)
   230  			if s != "" && !strings.HasPrefix(s, "--") {
   231  				declares = append(declares, strings.TrimRight(s, ";"))
   232  			}
   233  		}
   234  		sort.Strings(declares)
   235  
   236  		return declares
   237  	}
   238  	for _, tt := range []struct {
   239  		params   []interface{}
   240  		declares []string
   241  	}{
   242  		{
   243  			params: []interface{}{
   244  				sql.Named(
   245  					"values",
   246  					types.ListValue(
   247  						types.Uint64Value(1),
   248  						types.Uint64Value(2),
   249  						types.Uint64Value(3),
   250  						types.Uint64Value(4),
   251  						types.Uint64Value(5),
   252  					),
   253  				),
   254  			},
   255  			declares: []string{
   256  				"DECLARE $values AS List<Uint64>",
   257  			},
   258  		},
   259  		{
   260  			params: []interface{}{
   261  				sql.Named(
   262  					"delta",
   263  					types.IntervalValueFromDuration(time.Hour),
   264  				),
   265  			},
   266  			declares: []string{
   267  				"DECLARE $delta AS Interval",
   268  			},
   269  		},
   270  		{
   271  			params: []interface{}{
   272  				sql.Named(
   273  					"ts",
   274  					types.TimestampValueFromTime(time.Now()),
   275  				),
   276  			},
   277  			declares: []string{
   278  				"DECLARE $ts AS Timestamp",
   279  			},
   280  		},
   281  		{
   282  			params: []interface{}{
   283  				sql.Named(
   284  					"a",
   285  					types.BoolValue(true),
   286  				),
   287  				sql.Named(
   288  					"b",
   289  					types.Int64Value(123),
   290  				),
   291  				sql.Named(
   292  					"c",
   293  					types.OptionalValue(types.TextValue("test")),
   294  				),
   295  			},
   296  			declares: []string{
   297  				"DECLARE $a AS Bool",
   298  				"DECLARE $b AS Int64",
   299  				"DECLARE $c AS Optional<Utf8>",
   300  			},
   301  		},
   302  		{
   303  			params: []interface{}{
   304  				sql.Named(
   305  					"a",
   306  					types.BoolValue(true),
   307  				),
   308  				sql.Named(
   309  					"b",
   310  					types.Int64Value(123),
   311  				),
   312  				sql.Named(
   313  					"c",
   314  					types.OptionalValue(types.TextValue("test")),
   315  				),
   316  			},
   317  			declares: []string{
   318  				"DECLARE $a AS Bool",
   319  				"DECLARE $b AS Int64",
   320  				"DECLARE $c AS Optional<Utf8>",
   321  			},
   322  		},
   323  
   324  		{
   325  			params: []interface{}{
   326  				sql.Named("delta", time.Hour),
   327  			},
   328  			declares: []string{
   329  				"DECLARE $delta AS Interval",
   330  			},
   331  		},
   332  		{
   333  			params: []interface{}{
   334  				sql.Named("ts", time.Now()),
   335  			},
   336  			declares: []string{
   337  				"DECLARE $ts AS Timestamp",
   338  			},
   339  		},
   340  		{
   341  			params: []interface{}{
   342  				sql.Named("$a", true),
   343  				sql.Named("$b", int64(123)),
   344  				sql.Named("$c", func(s string) *string { return &s }("test")),
   345  			},
   346  			declares: []string{
   347  				"DECLARE $a AS Bool",
   348  				"DECLARE $b AS Int64",
   349  				"DECLARE $c AS Optional<Utf8>",
   350  			},
   351  		},
   352  		{
   353  			params: []interface{}{
   354  				sql.Named("$a", func(b bool) *bool { return &b }(true)),
   355  				sql.Named("b", func(i int64) *int64 { return &i }(123)),
   356  				sql.Named("c", func(s string) *string { return &s }("test")),
   357  			},
   358  			declares: []string{
   359  				"DECLARE $a AS Optional<Bool>",
   360  				"DECLARE $b AS Optional<Int64>",
   361  				"DECLARE $c AS Optional<Utf8>",
   362  			},
   363  		},
   364  	} {
   365  		t.Run("", func(t *testing.T) {
   366  			yql, _, err := b.RewriteQuery("", tt.params...)
   367  			require.NoError(t, err)
   368  			require.Equal(t, tt.declares, getDeclares(yql))
   369  		})
   370  	}
   371  }
   372  
   373  func TestToYdbParam(t *testing.T) {
   374  	for _, tt := range []struct {
   375  		name     string
   376  		param    sql.NamedArg
   377  		ydbParam table.ParameterOption
   378  		err      error
   379  	}{
   380  		{
   381  			name:     xtest.CurrentFileLine(),
   382  			param:    sql.Named("a", "b"),
   383  			ydbParam: table.ValueParam("$a", types.TextValue("b")),
   384  			err:      nil,
   385  		},
   386  		{
   387  			name:     xtest.CurrentFileLine(),
   388  			param:    sql.Named("a", 123),
   389  			ydbParam: table.ValueParam("$a", types.Int32Value(123)),
   390  			err:      nil,
   391  		},
   392  		{
   393  			name: xtest.CurrentFileLine(),
   394  			param: sql.Named("a", types.OptionalValue(types.TupleValue(
   395  				types.BytesValue([]byte("test")),
   396  				types.TextValue("test"),
   397  				types.Uint64Value(123),
   398  			))),
   399  			ydbParam: table.ValueParam("$a", types.OptionalValue(types.TupleValue(
   400  				types.BytesValue([]byte("test")),
   401  				types.TextValue("test"),
   402  				types.Uint64Value(123),
   403  			))),
   404  			err: nil,
   405  		},
   406  	} {
   407  		t.Run(tt.name, func(t *testing.T) {
   408  			ydbParam, err := ToYdbParam(tt.param)
   409  			if tt.err != nil {
   410  				require.Error(t, err)
   411  			} else {
   412  				require.NoError(t, err)
   413  				require.Equal(t, tt.ydbParam, ydbParam)
   414  			}
   415  		})
   416  	}
   417  }