github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/bind/params_test.go (about)

     1  package bind
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/google/uuid"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/params"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
    17  )
    18  
    19  func TestToValue(t *testing.T) {
    20  	for _, tt := range []struct {
    21  		src interface{}
    22  		dst types.Value
    23  		err error
    24  	}{
    25  		{
    26  			src: types.BoolValue(true),
    27  			dst: types.BoolValue(true),
    28  			err: nil,
    29  		},
    30  
    31  		{
    32  			src: nil,
    33  			dst: types.VoidValue(),
    34  			err: nil,
    35  		},
    36  
    37  		{
    38  			src: true,
    39  			dst: types.BoolValue(true),
    40  			err: nil,
    41  		},
    42  		{
    43  			src: func(v bool) *bool { return &v }(true),
    44  			dst: types.OptionalValue(types.BoolValue(true)),
    45  			err: nil,
    46  		},
    47  		{
    48  			src: func() *bool { return nil }(),
    49  			dst: types.NullValue(types.TypeBool),
    50  			err: nil,
    51  		},
    52  
    53  		{
    54  			src: 42,
    55  			dst: types.Int32Value(42),
    56  			err: nil,
    57  		},
    58  		{
    59  			src: func(v int) *int { return &v }(42),
    60  			dst: types.OptionalValue(types.Int32Value(42)),
    61  			err: nil,
    62  		},
    63  		{
    64  			src: func() *int { return nil }(),
    65  			dst: types.NullValue(types.TypeInt32),
    66  			err: nil,
    67  		},
    68  
    69  		{
    70  			src: uint(42),
    71  			dst: types.Uint32Value(42),
    72  			err: nil,
    73  		},
    74  		{
    75  			src: func(v uint) *uint { return &v }(42),
    76  			dst: types.OptionalValue(types.Uint32Value(42)),
    77  			err: nil,
    78  		},
    79  		{
    80  			src: func() *uint { return nil }(),
    81  			dst: types.NullValue(types.TypeUint32),
    82  			err: nil,
    83  		},
    84  
    85  		{
    86  			src: int8(42),
    87  			dst: types.Int8Value(42),
    88  			err: nil,
    89  		},
    90  		{
    91  			src: func(v int8) *int8 { return &v }(42),
    92  			dst: types.OptionalValue(types.Int8Value(42)),
    93  			err: nil,
    94  		},
    95  		{
    96  			src: func() *int8 { return nil }(),
    97  			dst: types.NullValue(types.TypeInt8),
    98  			err: nil,
    99  		},
   100  
   101  		{
   102  			src: uint8(42),
   103  			dst: types.Uint8Value(42),
   104  			err: nil,
   105  		},
   106  		{
   107  			src: func(v uint8) *uint8 { return &v }(42),
   108  			dst: types.OptionalValue(types.Uint8Value(42)),
   109  			err: nil,
   110  		},
   111  		{
   112  			src: func() *uint8 { return nil }(),
   113  			dst: types.NullValue(types.TypeUint8),
   114  			err: nil,
   115  		},
   116  
   117  		{
   118  			src: int16(42),
   119  			dst: types.Int16Value(42),
   120  			err: nil,
   121  		},
   122  		{
   123  			src: func(v int16) *int16 { return &v }(42),
   124  			dst: types.OptionalValue(types.Int16Value(42)),
   125  			err: nil,
   126  		},
   127  		{
   128  			src: func() *int16 { return nil }(),
   129  			dst: types.NullValue(types.TypeInt16),
   130  			err: nil,
   131  		},
   132  
   133  		{
   134  			src: uint16(42),
   135  			dst: types.Uint16Value(42),
   136  			err: nil,
   137  		},
   138  		{
   139  			src: func(v uint16) *uint16 { return &v }(42),
   140  			dst: types.OptionalValue(types.Uint16Value(42)),
   141  			err: nil,
   142  		},
   143  		{
   144  			src: func() *uint16 { return nil }(),
   145  			dst: types.NullValue(types.TypeUint16),
   146  			err: nil,
   147  		},
   148  
   149  		{
   150  			src: int32(42),
   151  			dst: types.Int32Value(42),
   152  			err: nil,
   153  		},
   154  		{
   155  			src: func(v int32) *int32 { return &v }(42),
   156  			dst: types.OptionalValue(types.Int32Value(42)),
   157  			err: nil,
   158  		},
   159  		{
   160  			src: func() *int32 { return nil }(),
   161  			dst: types.NullValue(types.TypeInt32),
   162  			err: nil,
   163  		},
   164  
   165  		{
   166  			src: uint32(42),
   167  			dst: types.Uint32Value(42),
   168  			err: nil,
   169  		},
   170  		{
   171  			src: func(v uint32) *uint32 { return &v }(42),
   172  			dst: types.OptionalValue(types.Uint32Value(42)),
   173  			err: nil,
   174  		},
   175  		{
   176  			src: func() *uint32 { return nil }(),
   177  			dst: types.NullValue(types.TypeUint32),
   178  			err: nil,
   179  		},
   180  
   181  		{
   182  			src: int64(42),
   183  			dst: types.Int64Value(42),
   184  			err: nil,
   185  		},
   186  		{
   187  			src: func(v int64) *int64 { return &v }(42),
   188  			dst: types.OptionalValue(types.Int64Value(42)),
   189  			err: nil,
   190  		},
   191  		{
   192  			src: func() *int64 { return nil }(),
   193  			dst: types.NullValue(types.TypeInt64),
   194  			err: nil,
   195  		},
   196  
   197  		{
   198  			src: uint64(42),
   199  			dst: types.Uint64Value(42),
   200  			err: nil,
   201  		},
   202  		{
   203  			src: func(v uint64) *uint64 { return &v }(42),
   204  			dst: types.OptionalValue(types.Uint64Value(42)),
   205  			err: nil,
   206  		},
   207  		{
   208  			src: func() *uint64 { return nil }(),
   209  			dst: types.NullValue(types.TypeUint64),
   210  			err: nil,
   211  		},
   212  
   213  		{
   214  			src: float32(42),
   215  			dst: types.FloatValue(42),
   216  			err: nil,
   217  		},
   218  		{
   219  			src: func(v float32) *float32 { return &v }(42),
   220  			dst: types.OptionalValue(types.FloatValue(42)),
   221  			err: nil,
   222  		},
   223  		{
   224  			src: func() *float32 { return nil }(),
   225  			dst: types.NullValue(types.TypeFloat),
   226  			err: nil,
   227  		},
   228  
   229  		{
   230  			src: float64(42),
   231  			dst: types.DoubleValue(42),
   232  			err: nil,
   233  		},
   234  		{
   235  			src: func(v float64) *float64 { return &v }(42),
   236  			dst: types.OptionalValue(types.DoubleValue(42)),
   237  			err: nil,
   238  		},
   239  		{
   240  			src: func() *float64 { return nil }(),
   241  			dst: types.NullValue(types.TypeDouble),
   242  			err: nil,
   243  		},
   244  
   245  		{
   246  			src: "test",
   247  			dst: types.TextValue("test"),
   248  			err: nil,
   249  		},
   250  		{
   251  			src: func(v string) *string { return &v }("test"),
   252  			dst: types.OptionalValue(types.TextValue("test")),
   253  			err: nil,
   254  		},
   255  		{
   256  			src: func() *string { return nil }(),
   257  			dst: types.NullValue(types.TypeText),
   258  			err: nil,
   259  		},
   260  
   261  		{
   262  			src: []byte("test"),
   263  			dst: types.BytesValue([]byte("test")),
   264  			err: nil,
   265  		},
   266  		{
   267  			src: func(v []byte) *[]byte { return &v }([]byte("test")),
   268  			dst: types.OptionalValue(types.BytesValue([]byte("test"))),
   269  			err: nil,
   270  		},
   271  		{
   272  			src: func() *[]byte { return nil }(),
   273  			dst: types.NullValue(types.TypeBytes),
   274  			err: nil,
   275  		},
   276  
   277  		{
   278  			src: []string{"test"},
   279  			dst: types.ListValue(types.TextValue("test")),
   280  			err: nil,
   281  		},
   282  		{
   283  			src: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   284  			dst: nil,
   285  			err: types.ErrIssue1501BadUUID,
   286  		},
   287  		{
   288  			src: func() *[16]byte { return nil }(),
   289  			dst: nil,
   290  			err: types.ErrIssue1501BadUUID,
   291  		},
   292  		{
   293  			src: func(v [16]byte) *[16]byte { return &v }([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}),
   294  			dst: nil,
   295  			err: types.ErrIssue1501BadUUID,
   296  		},
   297  		{
   298  			src: uuid.UUID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
   299  			dst: value.TextValue("01020304-0506-0708-090a-0b0c0d0e0f10"),
   300  			err: nil,
   301  		},
   302  		{
   303  			src: func(v uuid.UUID) *uuid.UUID { return &v }(uuid.UUID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}),
   304  			// uuid implemented driver.Valuer and doesn't set optional wrapper
   305  			dst: types.TextValue("01020304-0506-0708-090a-0b0c0d0e0f10"),
   306  			err: nil,
   307  		},
   308  		// https://github.com/ydb-platform/ydb-go-sdk/issues/1515
   309  		//{
   310  		//	src: func() *uuid.UUID { return nil }(),
   311  		//	dst: nil,
   312  		//	err: nil,
   313  		//},
   314  		{
   315  			src: time.Unix(42, 43),
   316  			dst: types.TimestampValueFromTime(time.Unix(42, 43)),
   317  			err: nil,
   318  		},
   319  		{
   320  			src: func(v time.Time) *time.Time { return &v }(time.Unix(42, 43)),
   321  			dst: types.OptionalValue(types.TimestampValueFromTime(time.Unix(42, 43))),
   322  			err: nil,
   323  		},
   324  		{
   325  			src: func() *time.Time { return nil }(),
   326  			dst: types.NullValue(types.TypeTimestamp),
   327  			err: nil,
   328  		},
   329  
   330  		{
   331  			src: time.Duration(42),
   332  			dst: types.IntervalValueFromDuration(time.Duration(42)),
   333  			err: nil,
   334  		},
   335  		{
   336  			src: func(v time.Duration) *time.Duration { return &v }(time.Duration(42)),
   337  			dst: types.OptionalValue(types.IntervalValueFromDuration(time.Duration(42))),
   338  			err: nil,
   339  		},
   340  		{
   341  			src: func() *time.Duration { return nil }(),
   342  			dst: types.NullValue(types.TypeInterval),
   343  			err: nil,
   344  		},
   345  	} {
   346  		t.Run(fmt.Sprintf("%T(%v)", tt.src, tt.src), func(t *testing.T) {
   347  			dst, err := toValue(tt.src)
   348  			if tt.err != nil {
   349  				require.ErrorIs(t, err, tt.err)
   350  			} else {
   351  				require.Equal(t, tt.dst, dst)
   352  			}
   353  		})
   354  	}
   355  }
   356  
   357  func named(name string, value interface{}) driver.NamedValue {
   358  	return driver.NamedValue{
   359  		Name:  name,
   360  		Value: value,
   361  	}
   362  }
   363  
   364  func TestYdbParam(t *testing.T) {
   365  	for _, tt := range []struct {
   366  		src interface{}
   367  		dst *params.Parameter
   368  		err error
   369  	}{
   370  		{
   371  			src: params.Named("$a", types.Int32Value(42)),
   372  			dst: params.Named("$a", types.Int32Value(42)),
   373  			err: nil,
   374  		},
   375  		{
   376  			src: named("a", int(42)),
   377  			dst: params.Named("$a", types.Int32Value(42)),
   378  			err: nil,
   379  		},
   380  		{
   381  			src: named("$a", int(42)),
   382  			dst: params.Named("$a", types.Int32Value(42)),
   383  			err: nil,
   384  		},
   385  		{
   386  			src: named("a", uint(42)),
   387  			dst: params.Named("$a", types.Uint32Value(42)),
   388  			err: nil,
   389  		},
   390  		{
   391  			src: driver.NamedValue{Value: uint(42)},
   392  			dst: nil,
   393  			err: errUnnamedParam,
   394  		},
   395  	} {
   396  		t.Run("", func(t *testing.T) {
   397  			dst, err := toYdbParam("", tt.src)
   398  			if tt.err != nil {
   399  				require.ErrorIs(t, err, tt.err)
   400  			} else {
   401  				require.Equal(t, tt.dst, dst)
   402  			}
   403  		})
   404  	}
   405  }
   406  
   407  func TestArgsToParams(t *testing.T) {
   408  	for _, tt := range []struct {
   409  		args   []interface{}
   410  		params []*params.Parameter
   411  		err    error
   412  	}{
   413  		{
   414  			args:   []interface{}{},
   415  			params: []*params.Parameter{},
   416  			err:    nil,
   417  		},
   418  		{
   419  			args: []interface{}{
   420  				1, uint64(2), "3",
   421  			},
   422  			params: []*params.Parameter{
   423  				params.Named("$p0", types.Int32Value(1)),
   424  				params.Named("$p1", types.Uint64Value(2)),
   425  				params.Named("$p2", types.TextValue("3")),
   426  			},
   427  			err: nil,
   428  		},
   429  		{
   430  			args: []interface{}{
   431  				table.NewQueryParameters(
   432  					params.Named("$p0", types.Int32Value(1)),
   433  					params.Named("$p1", types.Uint64Value(2)),
   434  					params.Named("$p2", types.TextValue("3")),
   435  				),
   436  				table.NewQueryParameters(
   437  					params.Named("$p0", types.Int32Value(1)),
   438  					params.Named("$p1", types.Uint64Value(2)),
   439  					params.Named("$p2", types.TextValue("3")),
   440  				),
   441  			},
   442  			err: errMultipleQueryParameters,
   443  		},
   444  		{
   445  			args: []interface{}{
   446  				params.Named("$p0", types.Int32Value(1)),
   447  				params.Named("$p1", types.Uint64Value(2)),
   448  				params.Named("$p2", types.TextValue("3")),
   449  			},
   450  			params: []*params.Parameter{
   451  				params.Named("$p0", types.Int32Value(1)),
   452  				params.Named("$p1", types.Uint64Value(2)),
   453  				params.Named("$p2", types.TextValue("3")),
   454  			},
   455  			err: nil,
   456  		},
   457  		{
   458  			args: []interface{}{
   459  				sql.Named("$p0", types.Int32Value(1)),
   460  				sql.Named("$p1", types.Uint64Value(2)),
   461  				sql.Named("$p2", types.TextValue("3")),
   462  			},
   463  			params: []*params.Parameter{
   464  				params.Named("$p0", types.Int32Value(1)),
   465  				params.Named("$p1", types.Uint64Value(2)),
   466  				params.Named("$p2", types.TextValue("3")),
   467  			},
   468  			err: nil,
   469  		},
   470  		{
   471  			args: []interface{}{
   472  				driver.NamedValue{Name: "$p0", Value: types.Int32Value(1)},
   473  				driver.NamedValue{Name: "$p1", Value: types.Uint64Value(2)},
   474  				driver.NamedValue{Name: "$p2", Value: types.TextValue("3")},
   475  			},
   476  			params: []*params.Parameter{
   477  				params.Named("$p0", types.Int32Value(1)),
   478  				params.Named("$p1", types.Uint64Value(2)),
   479  				params.Named("$p2", types.TextValue("3")),
   480  			},
   481  			err: nil,
   482  		},
   483  		{
   484  			args: []interface{}{
   485  				driver.NamedValue{Value: params.Named("$p0", types.Int32Value(1))},
   486  				driver.NamedValue{Value: params.Named("$p1", types.Uint64Value(2))},
   487  				driver.NamedValue{Value: params.Named("$p2", types.TextValue("3"))},
   488  			},
   489  			params: []*params.Parameter{
   490  				params.Named("$p0", types.Int32Value(1)),
   491  				params.Named("$p1", types.Uint64Value(2)),
   492  				params.Named("$p2", types.TextValue("3")),
   493  			},
   494  			err: nil,
   495  		},
   496  		{
   497  			args: []interface{}{
   498  				driver.NamedValue{Value: 1},
   499  				driver.NamedValue{Value: uint64(2)},
   500  				driver.NamedValue{Value: "3"},
   501  			},
   502  			params: []*params.Parameter{
   503  				params.Named("$p0", types.Int32Value(1)),
   504  				params.Named("$p1", types.Uint64Value(2)),
   505  				params.Named("$p2", types.TextValue("3")),
   506  			},
   507  			err: nil,
   508  		},
   509  		{
   510  			args: []interface{}{
   511  				driver.NamedValue{Value: table.NewQueryParameters(
   512  					params.Named("$p0", types.Int32Value(1)),
   513  					params.Named("$p1", types.Uint64Value(2)),
   514  					params.Named("$p2", types.TextValue("3")),
   515  				)},
   516  			},
   517  			params: []*params.Parameter{
   518  				params.Named("$p0", types.Int32Value(1)),
   519  				params.Named("$p1", types.Uint64Value(2)),
   520  				params.Named("$p2", types.TextValue("3")),
   521  			},
   522  			err: nil,
   523  		},
   524  		{
   525  			args: []interface{}{
   526  				driver.NamedValue{Value: table.NewQueryParameters(
   527  					params.Named("$p0", types.Int32Value(1)),
   528  					params.Named("$p1", types.Uint64Value(2)),
   529  					params.Named("$p2", types.TextValue("3")),
   530  				)},
   531  				driver.NamedValue{Value: params.Named("$p1", types.Uint64Value(2))},
   532  				driver.NamedValue{Value: params.Named("$p2", types.TextValue("3"))},
   533  			},
   534  			err: errMultipleQueryParameters,
   535  		},
   536  	} {
   537  		t.Run("", func(t *testing.T) {
   538  			params, err := Params(tt.args...)
   539  			if tt.err != nil {
   540  				require.ErrorIs(t, err, tt.err)
   541  			} else {
   542  				require.NoError(t, err)
   543  				require.Equal(t, tt.params, params)
   544  			}
   545  		})
   546  	}
   547  }