github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/sugar/params_test.go (about) 1 package sugar 2 3 import ( 4 "database/sql" 5 "sort" 6 "strings" 7 "testing" 8 "time" 9 10 "github.com/stretchr/testify/require" 11 12 "github.com/ydb-platform/ydb-go-sdk/v3/internal/bind" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" 14 "github.com/ydb-platform/ydb-go-sdk/v3/table" 15 "github.com/ydb-platform/ydb-go-sdk/v3/table/types" 16 "github.com/ydb-platform/ydb-go-sdk/v3/testutil" 17 ) 18 19 func TestGenerateDeclareSection(t *testing.T) { 20 splitDeclares := func(declaresSection string) (declares []string) { 21 for _, s := range strings.Split(declaresSection, ";") { 22 s = strings.TrimSpace(s) 23 if s != "" { 24 declares = append(declares, s) 25 } 26 } 27 sort.Strings(declares) 28 29 return declares 30 } 31 for _, tt := range []struct { 32 params *table.QueryParameters 33 declare string 34 }{ 35 { 36 params: table.NewQueryParameters( 37 table.ValueParam( 38 "$values", 39 types.ListValue( 40 types.Uint64Value(1), 41 types.Uint64Value(2), 42 types.Uint64Value(3), 43 types.Uint64Value(4), 44 types.Uint64Value(5), 45 ), 46 ), 47 ), 48 declare: ` 49 DECLARE $values AS List<Uint64>; 50 `, 51 }, 52 { 53 params: table.NewQueryParameters( 54 table.ValueParam( 55 "$delta", 56 types.IntervalValueFromDuration(time.Hour), 57 ), 58 ), 59 declare: ` 60 DECLARE $delta AS Interval; 61 `, 62 }, 63 { 64 params: table.NewQueryParameters( 65 table.ValueParam("$ts", types.TimestampValueFromTime(time.Now())), 66 ), 67 declare: ` 68 DECLARE $ts AS Timestamp; 69 `, 70 }, 71 { 72 params: table.NewQueryParameters( 73 table.ValueParam("$a", types.BoolValue(true)), 74 table.ValueParam("$b", types.Int64Value(123)), 75 table.ValueParam("$c", types.OptionalValue(types.TextValue("test"))), 76 ), 77 declare: ` 78 DECLARE $a AS Bool; 79 DECLARE $b AS Int64; 80 DECLARE $c AS Optional<Utf8>; 81 `, 82 }, 83 { 84 params: table.NewQueryParameters( 85 table.ValueParam("$a", types.BoolValue(true)), 86 table.ValueParam("b", types.Int64Value(123)), 87 table.ValueParam("c", types.OptionalValue(types.TextValue("test"))), 88 ), 89 declare: ` 90 DECLARE $a AS Bool; 91 DECLARE $b AS Int64; 92 DECLARE $c AS Optional<Utf8>; 93 `, 94 }, 95 } { 96 t.Run("", func(t *testing.T) { 97 declares, err := GenerateDeclareSection(tt.params) 98 require.NoError(t, err) 99 got := splitDeclares(declares) 100 want := splitDeclares(tt.declare) 101 if len(got) != len(want) { 102 t.Errorf("len(got) = %v, len(want) = %v", len(got), len(want)) 103 } else { 104 for i := range got { 105 if strings.TrimSpace(got[i]) != strings.TrimSpace(want[i]) { 106 t.Errorf( 107 "unexpected generation of declare section:\n%v\n\nwant:\n%v", 108 strings.Join(got, ";\n"), 109 strings.Join(want, ";\n"), 110 ) 111 } 112 } 113 } 114 }) 115 } 116 } 117 118 func TestGenerateDeclareSection_ParameterOption(t *testing.T) { 119 b := testutil.QueryBind(bind.AutoDeclare{}) 120 getDeclares := func(declaresSection string) (declares []string) { 121 for _, s := range strings.Split(declaresSection, "\n") { 122 s = strings.TrimSpace(s) 123 if s != "" && !strings.HasPrefix(s, "--") { 124 declares = append(declares, strings.TrimRight(s, ";")) 125 } 126 } 127 sort.Strings(declares) 128 129 return declares 130 } 131 for _, tt := range []struct { 132 params []interface{} 133 declares []string 134 }{ 135 { 136 params: []interface{}{ 137 table.ValueParam( 138 "$values", 139 types.ListValue( 140 types.Uint64Value(1), 141 types.Uint64Value(2), 142 types.Uint64Value(3), 143 types.Uint64Value(4), 144 types.Uint64Value(5), 145 ), 146 ), 147 }, 148 declares: []string{ 149 "DECLARE $values AS List<Uint64>", 150 }, 151 }, 152 { 153 params: []interface{}{ 154 table.ValueParam( 155 "$delta", 156 types.IntervalValueFromDuration(time.Hour), 157 ), 158 }, 159 declares: []string{ 160 "DECLARE $delta AS Interval", 161 }, 162 }, 163 { 164 params: []interface{}{ 165 table.ValueParam( 166 "$ts", 167 types.TimestampValueFromTime(time.Now()), 168 ), 169 }, 170 declares: []string{ 171 "DECLARE $ts AS Timestamp", 172 }, 173 }, 174 { 175 params: []interface{}{ 176 table.ValueParam( 177 "$a", 178 types.BoolValue(true), 179 ), 180 table.ValueParam( 181 "$b", 182 types.Int64Value(123), 183 ), 184 table.ValueParam( 185 "$c", 186 types.OptionalValue(types.TextValue("test")), 187 ), 188 }, 189 declares: []string{ 190 "DECLARE $a AS Bool", 191 "DECLARE $b AS Int64", 192 "DECLARE $c AS Optional<Utf8>", 193 }, 194 }, 195 { 196 params: []interface{}{ 197 table.ValueParam( 198 "$a", 199 types.BoolValue(true), 200 ), 201 table.ValueParam( 202 "b", 203 types.Int64Value(123), 204 ), 205 table.ValueParam( 206 "c", 207 types.OptionalValue(types.TextValue("test")), 208 ), 209 }, 210 declares: []string{ 211 "DECLARE $a AS Bool", 212 "DECLARE $b AS Int64", 213 "DECLARE $c AS Optional<Utf8>", 214 }, 215 }, 216 } { 217 t.Run("", func(t *testing.T) { 218 yql, _, err := b.RewriteQuery("", tt.params...) 219 require.NoError(t, err) 220 require.Equal(t, tt.declares, getDeclares(yql)) 221 }) 222 } 223 } 224 225 func TestGenerateDeclareSection_NamedArg(t *testing.T) { 226 b := testutil.QueryBind(bind.AutoDeclare{}) 227 getDeclares := func(declaresSection string) (declares []string) { 228 for _, s := range strings.Split(declaresSection, "\n") { 229 s = strings.TrimSpace(s) 230 if s != "" && !strings.HasPrefix(s, "--") { 231 declares = append(declares, strings.TrimRight(s, ";")) 232 } 233 } 234 sort.Strings(declares) 235 236 return declares 237 } 238 for _, tt := range []struct { 239 params []interface{} 240 declares []string 241 }{ 242 { 243 params: []interface{}{ 244 sql.Named( 245 "values", 246 types.ListValue( 247 types.Uint64Value(1), 248 types.Uint64Value(2), 249 types.Uint64Value(3), 250 types.Uint64Value(4), 251 types.Uint64Value(5), 252 ), 253 ), 254 }, 255 declares: []string{ 256 "DECLARE $values AS List<Uint64>", 257 }, 258 }, 259 { 260 params: []interface{}{ 261 sql.Named( 262 "delta", 263 types.IntervalValueFromDuration(time.Hour), 264 ), 265 }, 266 declares: []string{ 267 "DECLARE $delta AS Interval", 268 }, 269 }, 270 { 271 params: []interface{}{ 272 sql.Named( 273 "ts", 274 types.TimestampValueFromTime(time.Now()), 275 ), 276 }, 277 declares: []string{ 278 "DECLARE $ts AS Timestamp", 279 }, 280 }, 281 { 282 params: []interface{}{ 283 sql.Named( 284 "a", 285 types.BoolValue(true), 286 ), 287 sql.Named( 288 "b", 289 types.Int64Value(123), 290 ), 291 sql.Named( 292 "c", 293 types.OptionalValue(types.TextValue("test")), 294 ), 295 }, 296 declares: []string{ 297 "DECLARE $a AS Bool", 298 "DECLARE $b AS Int64", 299 "DECLARE $c AS Optional<Utf8>", 300 }, 301 }, 302 { 303 params: []interface{}{ 304 sql.Named( 305 "a", 306 types.BoolValue(true), 307 ), 308 sql.Named( 309 "b", 310 types.Int64Value(123), 311 ), 312 sql.Named( 313 "c", 314 types.OptionalValue(types.TextValue("test")), 315 ), 316 }, 317 declares: []string{ 318 "DECLARE $a AS Bool", 319 "DECLARE $b AS Int64", 320 "DECLARE $c AS Optional<Utf8>", 321 }, 322 }, 323 324 { 325 params: []interface{}{ 326 sql.Named("delta", time.Hour), 327 }, 328 declares: []string{ 329 "DECLARE $delta AS Interval", 330 }, 331 }, 332 { 333 params: []interface{}{ 334 sql.Named("ts", time.Now()), 335 }, 336 declares: []string{ 337 "DECLARE $ts AS Timestamp", 338 }, 339 }, 340 { 341 params: []interface{}{ 342 sql.Named("$a", true), 343 sql.Named("$b", int64(123)), 344 sql.Named("$c", func(s string) *string { return &s }("test")), 345 }, 346 declares: []string{ 347 "DECLARE $a AS Bool", 348 "DECLARE $b AS Int64", 349 "DECLARE $c AS Optional<Utf8>", 350 }, 351 }, 352 { 353 params: []interface{}{ 354 sql.Named("$a", func(b bool) *bool { return &b }(true)), 355 sql.Named("b", func(i int64) *int64 { return &i }(123)), 356 sql.Named("c", func(s string) *string { return &s }("test")), 357 }, 358 declares: []string{ 359 "DECLARE $a AS Optional<Bool>", 360 "DECLARE $b AS Optional<Int64>", 361 "DECLARE $c AS Optional<Utf8>", 362 }, 363 }, 364 } { 365 t.Run("", func(t *testing.T) { 366 yql, _, err := b.RewriteQuery("", tt.params...) 367 require.NoError(t, err) 368 require.Equal(t, tt.declares, getDeclares(yql)) 369 }) 370 } 371 } 372 373 func TestToYdbParam(t *testing.T) { 374 for _, tt := range []struct { 375 name string 376 param sql.NamedArg 377 ydbParam table.ParameterOption 378 err error 379 }{ 380 { 381 name: xtest.CurrentFileLine(), 382 param: sql.Named("a", "b"), 383 ydbParam: table.ValueParam("$a", types.TextValue("b")), 384 err: nil, 385 }, 386 { 387 name: xtest.CurrentFileLine(), 388 param: sql.Named("a", 123), 389 ydbParam: table.ValueParam("$a", types.Int32Value(123)), 390 err: nil, 391 }, 392 { 393 name: xtest.CurrentFileLine(), 394 param: sql.Named("a", types.OptionalValue(types.TupleValue( 395 types.BytesValue([]byte("test")), 396 types.TextValue("test"), 397 types.Uint64Value(123), 398 ))), 399 ydbParam: table.ValueParam("$a", types.OptionalValue(types.TupleValue( 400 types.BytesValue([]byte("test")), 401 types.TextValue("test"), 402 types.Uint64Value(123), 403 ))), 404 err: nil, 405 }, 406 } { 407 t.Run(tt.name, func(t *testing.T) { 408 ydbParam, err := ToYdbParam(tt.param) 409 if tt.err != nil { 410 require.Error(t, err) 411 } else { 412 require.NoError(t, err) 413 require.Equal(t, tt.ydbParam, ydbParam) 414 } 415 }) 416 } 417 }