github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/bind/params_test.go (about)

     1  package bind
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/require"
    11  
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/params"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
    15  )
    16  
    17  func TestToValue(t *testing.T) {
    18  	for _, tt := range []struct {
    19  		src interface{}
    20  		dst types.Value
    21  		err error
    22  	}{
    23  		{
    24  			src: types.BoolValue(true),
    25  			dst: types.BoolValue(true),
    26  			err: nil,
    27  		},
    28  
    29  		{
    30  			src: nil,
    31  			dst: types.VoidValue(),
    32  			err: nil,
    33  		},
    34  
    35  		{
    36  			src: true,
    37  			dst: types.BoolValue(true),
    38  			err: nil,
    39  		},
    40  		{
    41  			src: func(v bool) *bool { return &v }(true),
    42  			dst: types.OptionalValue(types.BoolValue(true)),
    43  			err: nil,
    44  		},
    45  		{
    46  			src: func() *bool { return nil }(),
    47  			dst: types.NullValue(types.TypeBool),
    48  			err: nil,
    49  		},
    50  
    51  		{
    52  			src: 42,
    53  			dst: types.Int32Value(42),
    54  			err: nil,
    55  		},
    56  		{
    57  			src: func(v int) *int { return &v }(42),
    58  			dst: types.OptionalValue(types.Int32Value(42)),
    59  			err: nil,
    60  		},
    61  		{
    62  			src: func() *int { return nil }(),
    63  			dst: types.NullValue(types.TypeInt32),
    64  			err: nil,
    65  		},
    66  
    67  		{
    68  			src: uint(42),
    69  			dst: types.Uint32Value(42),
    70  			err: nil,
    71  		},
    72  		{
    73  			src: func(v uint) *uint { return &v }(42),
    74  			dst: types.OptionalValue(types.Uint32Value(42)),
    75  			err: nil,
    76  		},
    77  		{
    78  			src: func() *uint { return nil }(),
    79  			dst: types.NullValue(types.TypeUint32),
    80  			err: nil,
    81  		},
    82  
    83  		{
    84  			src: int8(42),
    85  			dst: types.Int8Value(42),
    86  			err: nil,
    87  		},
    88  		{
    89  			src: func(v int8) *int8 { return &v }(42),
    90  			dst: types.OptionalValue(types.Int8Value(42)),
    91  			err: nil,
    92  		},
    93  		{
    94  			src: func() *int8 { return nil }(),
    95  			dst: types.NullValue(types.TypeInt8),
    96  			err: nil,
    97  		},
    98  
    99  		{
   100  			src: uint8(42),
   101  			dst: types.Uint8Value(42),
   102  			err: nil,
   103  		},
   104  		{
   105  			src: func(v uint8) *uint8 { return &v }(42),
   106  			dst: types.OptionalValue(types.Uint8Value(42)),
   107  			err: nil,
   108  		},
   109  		{
   110  			src: func() *uint8 { return nil }(),
   111  			dst: types.NullValue(types.TypeUint8),
   112  			err: nil,
   113  		},
   114  
   115  		{
   116  			src: int16(42),
   117  			dst: types.Int16Value(42),
   118  			err: nil,
   119  		},
   120  		{
   121  			src: func(v int16) *int16 { return &v }(42),
   122  			dst: types.OptionalValue(types.Int16Value(42)),
   123  			err: nil,
   124  		},
   125  		{
   126  			src: func() *int16 { return nil }(),
   127  			dst: types.NullValue(types.TypeInt16),
   128  			err: nil,
   129  		},
   130  
   131  		{
   132  			src: uint16(42),
   133  			dst: types.Uint16Value(42),
   134  			err: nil,
   135  		},
   136  		{
   137  			src: func(v uint16) *uint16 { return &v }(42),
   138  			dst: types.OptionalValue(types.Uint16Value(42)),
   139  			err: nil,
   140  		},
   141  		{
   142  			src: func() *uint16 { return nil }(),
   143  			dst: types.NullValue(types.TypeUint16),
   144  			err: nil,
   145  		},
   146  
   147  		{
   148  			src: int32(42),
   149  			dst: types.Int32Value(42),
   150  			err: nil,
   151  		},
   152  		{
   153  			src: func(v int32) *int32 { return &v }(42),
   154  			dst: types.OptionalValue(types.Int32Value(42)),
   155  			err: nil,
   156  		},
   157  		{
   158  			src: func() *int32 { return nil }(),
   159  			dst: types.NullValue(types.TypeInt32),
   160  			err: nil,
   161  		},
   162  
   163  		{
   164  			src: uint32(42),
   165  			dst: types.Uint32Value(42),
   166  			err: nil,
   167  		},
   168  		{
   169  			src: func(v uint32) *uint32 { return &v }(42),
   170  			dst: types.OptionalValue(types.Uint32Value(42)),
   171  			err: nil,
   172  		},
   173  		{
   174  			src: func() *uint32 { return nil }(),
   175  			dst: types.NullValue(types.TypeUint32),
   176  			err: nil,
   177  		},
   178  
   179  		{
   180  			src: int64(42),
   181  			dst: types.Int64Value(42),
   182  			err: nil,
   183  		},
   184  		{
   185  			src: func(v int64) *int64 { return &v }(42),
   186  			dst: types.OptionalValue(types.Int64Value(42)),
   187  			err: nil,
   188  		},
   189  		{
   190  			src: func() *int64 { return nil }(),
   191  			dst: types.NullValue(types.TypeInt64),
   192  			err: nil,
   193  		},
   194  
   195  		{
   196  			src: uint64(42),
   197  			dst: types.Uint64Value(42),
   198  			err: nil,
   199  		},
   200  		{
   201  			src: func(v uint64) *uint64 { return &v }(42),
   202  			dst: types.OptionalValue(types.Uint64Value(42)),
   203  			err: nil,
   204  		},
   205  		{
   206  			src: func() *uint64 { return nil }(),
   207  			dst: types.NullValue(types.TypeUint64),
   208  			err: nil,
   209  		},
   210  
   211  		{
   212  			src: float32(42),
   213  			dst: types.FloatValue(42),
   214  			err: nil,
   215  		},
   216  		{
   217  			src: func(v float32) *float32 { return &v }(42),
   218  			dst: types.OptionalValue(types.FloatValue(42)),
   219  			err: nil,
   220  		},
   221  		{
   222  			src: func() *float32 { return nil }(),
   223  			dst: types.NullValue(types.TypeFloat),
   224  			err: nil,
   225  		},
   226  
   227  		{
   228  			src: float64(42),
   229  			dst: types.DoubleValue(42),
   230  			err: nil,
   231  		},
   232  		{
   233  			src: func(v float64) *float64 { return &v }(42),
   234  			dst: types.OptionalValue(types.DoubleValue(42)),
   235  			err: nil,
   236  		},
   237  		{
   238  			src: func() *float64 { return nil }(),
   239  			dst: types.NullValue(types.TypeDouble),
   240  			err: nil,
   241  		},
   242  
   243  		{
   244  			src: "test",
   245  			dst: types.TextValue("test"),
   246  			err: nil,
   247  		},
   248  		{
   249  			src: func(v string) *string { return &v }("test"),
   250  			dst: types.OptionalValue(types.TextValue("test")),
   251  			err: nil,
   252  		},
   253  		{
   254  			src: func() *string { return nil }(),
   255  			dst: types.NullValue(types.TypeText),
   256  			err: nil,
   257  		},
   258  
   259  		{
   260  			src: []byte("test"),
   261  			dst: types.BytesValue([]byte("test")),
   262  			err: nil,
   263  		},
   264  		{
   265  			src: func(v []byte) *[]byte { return &v }([]byte("test")),
   266  			dst: types.OptionalValue(types.BytesValue([]byte("test"))),
   267  			err: nil,
   268  		},
   269  		{
   270  			src: func() *[]byte { return nil }(),
   271  			dst: types.NullValue(types.TypeBytes),
   272  			err: nil,
   273  		},
   274  
   275  		{
   276  			src: []string{"test"},
   277  			dst: types.ListValue(types.TextValue("test")),
   278  			err: nil,
   279  		},
   280  
   281  		{
   282  			src: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   283  			dst: types.UUIDValue([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}),
   284  			err: nil,
   285  		},
   286  		{
   287  			src: func(v [16]byte) *[16]byte { return &v }([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}),
   288  			dst: types.OptionalValue(types.UUIDValue([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})),
   289  			err: nil,
   290  		},
   291  		{
   292  			src: func() *[16]byte { return nil }(),
   293  			dst: types.NullValue(types.TypeUUID),
   294  			err: nil,
   295  		},
   296  
   297  		{
   298  			src: time.Unix(42, 43),
   299  			dst: types.TimestampValueFromTime(time.Unix(42, 43)),
   300  			err: nil,
   301  		},
   302  		{
   303  			src: func(v time.Time) *time.Time { return &v }(time.Unix(42, 43)),
   304  			dst: types.OptionalValue(types.TimestampValueFromTime(time.Unix(42, 43))),
   305  			err: nil,
   306  		},
   307  		{
   308  			src: func() *time.Time { return nil }(),
   309  			dst: types.NullValue(types.TypeTimestamp),
   310  			err: nil,
   311  		},
   312  
   313  		{
   314  			src: time.Duration(42),
   315  			dst: types.IntervalValueFromDuration(time.Duration(42)),
   316  			err: nil,
   317  		},
   318  		{
   319  			src: func(v time.Duration) *time.Duration { return &v }(time.Duration(42)),
   320  			dst: types.OptionalValue(types.IntervalValueFromDuration(time.Duration(42))),
   321  			err: nil,
   322  		},
   323  		{
   324  			src: func() *time.Duration { return nil }(),
   325  			dst: types.NullValue(types.TypeInterval),
   326  			err: nil,
   327  		},
   328  	} {
   329  		t.Run(fmt.Sprintf("%T(%v)", tt.src, tt.src), func(t *testing.T) {
   330  			dst, err := toValue(tt.src)
   331  			if tt.err != nil {
   332  				require.ErrorIs(t, err, tt.err)
   333  			} else {
   334  				require.Equal(t, tt.dst, dst)
   335  			}
   336  		})
   337  	}
   338  }
   339  
   340  func named(name string, value interface{}) driver.NamedValue {
   341  	return driver.NamedValue{
   342  		Name:  name,
   343  		Value: value,
   344  	}
   345  }
   346  
   347  func TestYdbParam(t *testing.T) {
   348  	for _, tt := range []struct {
   349  		src interface{}
   350  		dst *params.Parameter
   351  		err error
   352  	}{
   353  		{
   354  			src: params.Named("$a", types.Int32Value(42)),
   355  			dst: params.Named("$a", types.Int32Value(42)),
   356  			err: nil,
   357  		},
   358  		{
   359  			src: named("a", int(42)),
   360  			dst: params.Named("$a", types.Int32Value(42)),
   361  			err: nil,
   362  		},
   363  		{
   364  			src: named("$a", int(42)),
   365  			dst: params.Named("$a", types.Int32Value(42)),
   366  			err: nil,
   367  		},
   368  		{
   369  			src: named("a", uint(42)),
   370  			dst: params.Named("$a", types.Uint32Value(42)),
   371  			err: nil,
   372  		},
   373  		{
   374  			src: driver.NamedValue{Value: uint(42)},
   375  			dst: nil,
   376  			err: errUnnamedParam,
   377  		},
   378  	} {
   379  		t.Run("", func(t *testing.T) {
   380  			dst, err := toYdbParam("", tt.src)
   381  			if tt.err != nil {
   382  				require.ErrorIs(t, err, tt.err)
   383  			} else {
   384  				require.Equal(t, tt.dst, dst)
   385  			}
   386  		})
   387  	}
   388  }
   389  
   390  func TestArgsToParams(t *testing.T) {
   391  	for _, tt := range []struct {
   392  		args   []interface{}
   393  		params []*params.Parameter
   394  		err    error
   395  	}{
   396  		{
   397  			args:   []interface{}{},
   398  			params: []*params.Parameter{},
   399  			err:    nil,
   400  		},
   401  		{
   402  			args: []interface{}{
   403  				1, uint64(2), "3",
   404  			},
   405  			params: []*params.Parameter{
   406  				params.Named("$p0", types.Int32Value(1)),
   407  				params.Named("$p1", types.Uint64Value(2)),
   408  				params.Named("$p2", types.TextValue("3")),
   409  			},
   410  			err: nil,
   411  		},
   412  		{
   413  			args: []interface{}{
   414  				table.NewQueryParameters(
   415  					params.Named("$p0", types.Int32Value(1)),
   416  					params.Named("$p1", types.Uint64Value(2)),
   417  					params.Named("$p2", types.TextValue("3")),
   418  				),
   419  				table.NewQueryParameters(
   420  					params.Named("$p0", types.Int32Value(1)),
   421  					params.Named("$p1", types.Uint64Value(2)),
   422  					params.Named("$p2", types.TextValue("3")),
   423  				),
   424  			},
   425  			err: errMultipleQueryParameters,
   426  		},
   427  		{
   428  			args: []interface{}{
   429  				params.Named("$p0", types.Int32Value(1)),
   430  				params.Named("$p1", types.Uint64Value(2)),
   431  				params.Named("$p2", types.TextValue("3")),
   432  			},
   433  			params: []*params.Parameter{
   434  				params.Named("$p0", types.Int32Value(1)),
   435  				params.Named("$p1", types.Uint64Value(2)),
   436  				params.Named("$p2", types.TextValue("3")),
   437  			},
   438  			err: nil,
   439  		},
   440  		{
   441  			args: []interface{}{
   442  				sql.Named("$p0", types.Int32Value(1)),
   443  				sql.Named("$p1", types.Uint64Value(2)),
   444  				sql.Named("$p2", types.TextValue("3")),
   445  			},
   446  			params: []*params.Parameter{
   447  				params.Named("$p0", types.Int32Value(1)),
   448  				params.Named("$p1", types.Uint64Value(2)),
   449  				params.Named("$p2", types.TextValue("3")),
   450  			},
   451  			err: nil,
   452  		},
   453  		{
   454  			args: []interface{}{
   455  				driver.NamedValue{Name: "$p0", Value: types.Int32Value(1)},
   456  				driver.NamedValue{Name: "$p1", Value: types.Uint64Value(2)},
   457  				driver.NamedValue{Name: "$p2", Value: types.TextValue("3")},
   458  			},
   459  			params: []*params.Parameter{
   460  				params.Named("$p0", types.Int32Value(1)),
   461  				params.Named("$p1", types.Uint64Value(2)),
   462  				params.Named("$p2", types.TextValue("3")),
   463  			},
   464  			err: nil,
   465  		},
   466  		{
   467  			args: []interface{}{
   468  				driver.NamedValue{Value: params.Named("$p0", types.Int32Value(1))},
   469  				driver.NamedValue{Value: params.Named("$p1", types.Uint64Value(2))},
   470  				driver.NamedValue{Value: params.Named("$p2", types.TextValue("3"))},
   471  			},
   472  			params: []*params.Parameter{
   473  				params.Named("$p0", types.Int32Value(1)),
   474  				params.Named("$p1", types.Uint64Value(2)),
   475  				params.Named("$p2", types.TextValue("3")),
   476  			},
   477  			err: nil,
   478  		},
   479  		{
   480  			args: []interface{}{
   481  				driver.NamedValue{Value: 1},
   482  				driver.NamedValue{Value: uint64(2)},
   483  				driver.NamedValue{Value: "3"},
   484  			},
   485  			params: []*params.Parameter{
   486  				params.Named("$p0", types.Int32Value(1)),
   487  				params.Named("$p1", types.Uint64Value(2)),
   488  				params.Named("$p2", types.TextValue("3")),
   489  			},
   490  			err: nil,
   491  		},
   492  		{
   493  			args: []interface{}{
   494  				driver.NamedValue{Value: table.NewQueryParameters(
   495  					params.Named("$p0", types.Int32Value(1)),
   496  					params.Named("$p1", types.Uint64Value(2)),
   497  					params.Named("$p2", types.TextValue("3")),
   498  				)},
   499  			},
   500  			params: []*params.Parameter{
   501  				params.Named("$p0", types.Int32Value(1)),
   502  				params.Named("$p1", types.Uint64Value(2)),
   503  				params.Named("$p2", types.TextValue("3")),
   504  			},
   505  			err: nil,
   506  		},
   507  		{
   508  			args: []interface{}{
   509  				driver.NamedValue{Value: table.NewQueryParameters(
   510  					params.Named("$p0", types.Int32Value(1)),
   511  					params.Named("$p1", types.Uint64Value(2)),
   512  					params.Named("$p2", types.TextValue("3")),
   513  				)},
   514  				driver.NamedValue{Value: params.Named("$p1", types.Uint64Value(2))},
   515  				driver.NamedValue{Value: params.Named("$p2", types.TextValue("3"))},
   516  			},
   517  			err: errMultipleQueryParameters,
   518  		},
   519  	} {
   520  		t.Run("", func(t *testing.T) {
   521  			params, err := Params(tt.args...)
   522  			if tt.err != nil {
   523  				require.ErrorIs(t, err, tt.err)
   524  			} else {
   525  				require.NoError(t, err)
   526  				require.Equal(t, tt.params, params)
   527  			}
   528  		})
   529  	}
   530  }