vitess.io/vitess@v0.16.2/go/vt/sqlparser/normalizer_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "bytes" 21 "fmt" 22 "math/rand" 23 "reflect" 24 "regexp" 25 "strconv" 26 "testing" 27 28 "github.com/stretchr/testify/assert" 29 30 "github.com/stretchr/testify/require" 31 32 "vitess.io/vitess/go/sqltypes" 33 querypb "vitess.io/vitess/go/vt/proto/query" 34 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 35 "vitess.io/vitess/go/vt/vterrors" 36 ) 37 38 func TestNormalize(t *testing.T) { 39 prefix := "bv" 40 testcases := []struct { 41 in string 42 outstmt string 43 outbv map[string]*querypb.BindVariable 44 }{{ 45 // str val 46 in: "select * from t where foobar = 'aa'", 47 outstmt: "select * from t where foobar = :foobar", 48 outbv: map[string]*querypb.BindVariable{ 49 "foobar": sqltypes.StringBindVariable("aa"), 50 }, 51 }, { 52 // placeholder 53 in: "select * from t where col=?", 54 outstmt: "select * from t where col = :v1", 55 outbv: map[string]*querypb.BindVariable{}, 56 }, { 57 // qualified table name 58 in: "select * from `t` where col=?", 59 outstmt: "select * from t where col = :v1", 60 outbv: map[string]*querypb.BindVariable{}, 61 }, { 62 // str val in select 63 in: "select 'aa' from t", 64 outstmt: "select :bv1 from t", 65 outbv: map[string]*querypb.BindVariable{ 66 "bv1": sqltypes.StringBindVariable("aa"), 67 }, 68 }, { 69 // int val 70 in: "select * from t where foobar = 1", 71 outstmt: "select * from t where foobar = :foobar", 72 outbv: map[string]*querypb.BindVariable{ 73 "foobar": sqltypes.Int64BindVariable(1), 74 }, 75 }, { 76 // float val 77 in: "select * from t where foobar = 1.2", 78 outstmt: "select * from t where foobar = :foobar", 79 outbv: map[string]*querypb.BindVariable{ 80 "foobar": sqltypes.DecimalBindVariable(1.2), 81 }, 82 }, { 83 // multiple vals 84 in: "select * from t where foo = 1.2 and bar = 2", 85 outstmt: "select * from t where foo = :foo and bar = :bar", 86 outbv: map[string]*querypb.BindVariable{ 87 "foo": sqltypes.DecimalBindVariable(1.2), 88 "bar": sqltypes.Int64BindVariable(2), 89 }, 90 }, { 91 // bv collision 92 in: "select * from t where foo = :bar and bar = 12", 93 outstmt: "select * from t where foo = :bar and bar = :bar1", 94 outbv: map[string]*querypb.BindVariable{ 95 "bar1": sqltypes.Int64BindVariable(12), 96 }, 97 }, { 98 // val reuse 99 in: "select * from t where foo = 1 and bar = 1", 100 outstmt: "select * from t where foo = :foo and bar = :foo", 101 outbv: map[string]*querypb.BindVariable{ 102 "foo": sqltypes.Int64BindVariable(1), 103 }, 104 }, { 105 // ints and strings are different 106 in: "select * from t where foo = 1 and bar = '1'", 107 outstmt: "select * from t where foo = :foo and bar = :bar", 108 outbv: map[string]*querypb.BindVariable{ 109 "foo": sqltypes.Int64BindVariable(1), 110 "bar": sqltypes.StringBindVariable("1"), 111 }, 112 }, { 113 // val should not be reused for non-select statements 114 in: "insert into a values(1, 1)", 115 outstmt: "insert into a values (:bv1, :bv2)", 116 outbv: map[string]*querypb.BindVariable{ 117 "bv1": sqltypes.Int64BindVariable(1), 118 "bv2": sqltypes.Int64BindVariable(1), 119 }, 120 }, { 121 // val should be reused only in subqueries of DMLs 122 in: "update a set v1=(select 5 from t), v2=5, v3=(select 5 from t), v4=5", 123 outstmt: "update a set v1 = (select :bv1 from t), v2 = :bv1, v3 = (select :bv1 from t), v4 = :bv1", 124 outbv: map[string]*querypb.BindVariable{ 125 "bv1": sqltypes.Int64BindVariable(5), 126 }, 127 }, { 128 // list vars should work for DMLs also 129 in: "update a set v1=5 where v2 in (1, 4, 5)", 130 outstmt: "update a set v1 = :v1 where v2 in ::bv1", 131 outbv: map[string]*querypb.BindVariable{ 132 "v1": sqltypes.Int64BindVariable(5), 133 "bv1": sqltypes.TestBindVariable([]any{1, 4, 5}), 134 }, 135 }, { 136 // Hex number values should work for selects 137 in: "select * from t where foo = 0x1234", 138 outstmt: "select * from t where foo = :foo", 139 outbv: map[string]*querypb.BindVariable{ 140 "foo": sqltypes.HexNumBindVariable([]byte("0x1234")), 141 }, 142 }, { 143 // Hex encoded string values should work for selects 144 in: "select * from t where foo = x'7b7d'", 145 outstmt: "select * from t where foo = :foo", 146 outbv: map[string]*querypb.BindVariable{ 147 "foo": sqltypes.HexValBindVariable([]byte("x'7b7d'")), 148 }, 149 }, { 150 // Ensure that hex notation bind vars work with collation based conversions 151 in: "select convert(x'7b7d' using utf8mb4) from dual", 152 outstmt: "select convert(:bv1 using utf8mb4) from dual", 153 outbv: map[string]*querypb.BindVariable{ 154 "bv1": sqltypes.HexValBindVariable([]byte("x'7b7d'")), 155 }, 156 }, { 157 // Hex number values should work for DMLs 158 in: "update a set foo = 0x12", 159 outstmt: "update a set foo = :foo", 160 outbv: map[string]*querypb.BindVariable{ 161 "foo": sqltypes.HexNumBindVariable([]byte("0x12")), 162 }, 163 }, { 164 // Bin values work fine 165 in: "select * from t where foo = b'11'", 166 outstmt: "select * from t where foo = :foo", 167 outbv: map[string]*querypb.BindVariable{ 168 "foo": sqltypes.HexNumBindVariable([]byte("0x3")), 169 }, 170 }, { 171 // Bin value does not convert for DMLs 172 in: "update a set v1 = b'11'", 173 outstmt: "update a set v1 = :v1", 174 outbv: map[string]*querypb.BindVariable{ 175 "v1": sqltypes.HexNumBindVariable([]byte("0x3")), 176 }, 177 }, { 178 // ORDER BY column_position 179 in: "select a, b from t order by 1 asc", 180 outstmt: "select a, b from t order by 1 asc", 181 outbv: map[string]*querypb.BindVariable{}, 182 }, { 183 // GROUP BY column_position 184 in: "select a, b from t group by 1", 185 outstmt: "select a, b from t group by 1", 186 outbv: map[string]*querypb.BindVariable{}, 187 }, { 188 // ORDER BY with literal inside complex expression 189 in: "select a, b from t order by field(a,1,2,3) asc", 190 outstmt: "select a, b from t order by field(a, :bv1, :bv2, :bv3) asc", 191 outbv: map[string]*querypb.BindVariable{ 192 "bv1": sqltypes.Int64BindVariable(1), 193 "bv2": sqltypes.Int64BindVariable(2), 194 "bv3": sqltypes.Int64BindVariable(3), 195 }, 196 }, { 197 // ORDER BY variable 198 in: "select a, b from t order by c asc", 199 outstmt: "select a, b from t order by c asc", 200 outbv: map[string]*querypb.BindVariable{}, 201 }, { 202 // Values up to len 256 will reuse. 203 in: fmt.Sprintf("select * from t where foo = '%256s' and bar = '%256s'", "a", "a"), 204 outstmt: "select * from t where foo = :foo and bar = :foo", 205 outbv: map[string]*querypb.BindVariable{ 206 "foo": sqltypes.StringBindVariable(fmt.Sprintf("%256s", "a")), 207 }, 208 }, { 209 // Values greater than len 256 will not reuse. 210 in: fmt.Sprintf("select * from t where foo = '%257s' and bar = '%257s'", "b", "b"), 211 outstmt: "select * from t where foo = :foo and bar = :bar", 212 outbv: map[string]*querypb.BindVariable{ 213 "foo": sqltypes.StringBindVariable(fmt.Sprintf("%257s", "b")), 214 "bar": sqltypes.StringBindVariable(fmt.Sprintf("%257s", "b")), 215 }, 216 }, { 217 // bad int 218 in: "select * from t where v1 = 12345678901234567890", 219 outstmt: "select * from t where v1 = 12345678901234567890", 220 outbv: map[string]*querypb.BindVariable{}, 221 }, { 222 // comparison with no vals 223 in: "select * from t where v1 = v2", 224 outstmt: "select * from t where v1 = v2", 225 outbv: map[string]*querypb.BindVariable{}, 226 }, { 227 // IN clause with existing bv 228 in: "select * from t where v1 in ::list", 229 outstmt: "select * from t where v1 in ::list", 230 outbv: map[string]*querypb.BindVariable{}, 231 }, { 232 // IN clause with non-val values 233 in: "select * from t where v1 in (1, a)", 234 outstmt: "select * from t where v1 in (:bv1, a)", 235 outbv: map[string]*querypb.BindVariable{ 236 "bv1": sqltypes.Int64BindVariable(1), 237 }, 238 }, { 239 // IN clause with vals 240 in: "select * from t where v1 in (1, '2')", 241 outstmt: "select * from t where v1 in ::bv1", 242 outbv: map[string]*querypb.BindVariable{ 243 "bv1": sqltypes.TestBindVariable([]any{1, "2"}), 244 }, 245 }, { 246 // EXPLAIN queries 247 in: "explain select * from t where v1 in (1, '2')", 248 outstmt: "explain select * from t where v1 in ::bv1", 249 outbv: map[string]*querypb.BindVariable{ 250 "bv1": sqltypes.TestBindVariable([]any{1, "2"}), 251 }, 252 }, { 253 // NOT IN clause 254 in: "select * from t where v1 not in (1, '2')", 255 outstmt: "select * from t where v1 not in ::bv1", 256 outbv: map[string]*querypb.BindVariable{ 257 "bv1": sqltypes.TestBindVariable([]any{1, "2"}), 258 }, 259 }, { 260 // Do not normalize cast/convert types 261 in: `select CAST("test" AS CHAR(60))`, 262 outstmt: `select cast(:bv1 as CHAR(60)) from dual`, 263 outbv: map[string]*querypb.BindVariable{ 264 "bv1": sqltypes.StringBindVariable("test"), 265 }, 266 }, { 267 // insert syntax 268 in: "insert into a (v1, v2, v3) values (1, '2', 3)", 269 outstmt: "insert into a(v1, v2, v3) values (:bv1, :bv2, :bv3)", 270 outbv: map[string]*querypb.BindVariable{ 271 "bv1": sqltypes.Int64BindVariable(1), 272 "bv2": sqltypes.StringBindVariable("2"), 273 "bv3": sqltypes.Int64BindVariable(3), 274 }, 275 }, { 276 // BitVal should also be normalized 277 in: `select b'1', 0b01, b'1010', 0b1111111`, 278 outstmt: `select :bv1, :bv2, :bv3, :bv4 from dual`, 279 outbv: map[string]*querypb.BindVariable{ 280 "bv1": sqltypes.HexNumBindVariable([]byte("0x1")), 281 "bv2": sqltypes.HexNumBindVariable([]byte("0x1")), 282 "bv3": sqltypes.HexNumBindVariable([]byte("0xa")), 283 "bv4": sqltypes.HexNumBindVariable([]byte("0x7f")), 284 }, 285 }, { 286 // DateVal should also be normalized 287 in: `select date'2022-08-06'`, 288 outstmt: `select :bv1 from dual`, 289 outbv: map[string]*querypb.BindVariable{ 290 "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Date, []byte("2022-08-06"))), 291 }, 292 }, { 293 // TimeVal should also be normalized 294 in: `select time'17:05:12'`, 295 outstmt: `select :bv1 from dual`, 296 outbv: map[string]*querypb.BindVariable{ 297 "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Time, []byte("17:05:12"))), 298 }, 299 }, { 300 // TimestampVal should also be normalized 301 in: `select timestamp'2022-08-06 17:05:12'`, 302 outstmt: `select :bv1 from dual`, 303 outbv: map[string]*querypb.BindVariable{ 304 "bv1": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Datetime, []byte("2022-08-06 17:05:12"))), 305 }, 306 }, { 307 // TimestampVal should also be normalized 308 in: `explain select comms_by_companies.* from comms_by_companies where comms_by_companies.id = 'rjve634shXzaavKHbAH16ql6OrxJ' limit 1,1`, 309 outstmt: `explain select comms_by_companies.* from comms_by_companies where comms_by_companies.id = :comms_by_companies_id limit :bv1, :bv2`, 310 outbv: map[string]*querypb.BindVariable{ 311 "bv1": sqltypes.Int64BindVariable(1), 312 "bv2": sqltypes.Int64BindVariable(1), 313 "comms_by_companies_id": sqltypes.StringBindVariable("rjve634shXzaavKHbAH16ql6OrxJ"), 314 }, 315 }, { 316 // Int leading with zero should also be normalized 317 in: `select * from t where zipcode = 01001900`, 318 outstmt: `select * from t where zipcode = :zipcode`, 319 outbv: map[string]*querypb.BindVariable{ 320 "zipcode": sqltypes.ValueBindVariable(sqltypes.MakeTrusted(sqltypes.Int64, []byte("01001900"))), 321 }, 322 }, { 323 // literals in limit and offset should not reuse bindvars 324 in: `select * from t where id = 10 limit 10 offset 10`, 325 outstmt: `select * from t where id = :id limit :bv1, :bv2`, 326 outbv: map[string]*querypb.BindVariable{ 327 "bv1": sqltypes.Int64BindVariable(10), 328 "bv2": sqltypes.Int64BindVariable(10), 329 "id": sqltypes.Int64BindVariable(10), 330 }, 331 }, { 332 // we don't want to replace literals on the select expressions of a derived table 333 // these expressions can be referenced from the outside, 334 // and changing them to bindvars can change the meaning of the query 335 // example of problematic query: select tmp.`1` from (select 1) as tmp 336 in: `select * from (select 12) as t`, 337 outstmt: `select * from (select 12 from dual) as t`, 338 outbv: map[string]*querypb.BindVariable{}, 339 }} 340 for _, tc := range testcases { 341 t.Run(tc.in, func(t *testing.T) { 342 stmt, err := Parse(tc.in) 343 require.NoError(t, err) 344 known := GetBindvars(stmt) 345 bv := make(map[string]*querypb.BindVariable) 346 require.NoError(t, Normalize(stmt, NewReservedVars(prefix, known), bv)) 347 assert.Equal(t, tc.outstmt, String(stmt)) 348 assert.Equal(t, tc.outbv, bv) 349 }) 350 } 351 } 352 353 func TestNormalizeInvalidDates(t *testing.T) { 354 testcases := []struct { 355 in string 356 err error 357 }{{ 358 in: "select date'foo'", 359 err: vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongValue, "incorrect DATE value: '%s'", "foo"), 360 }, { 361 in: "select time'foo'", 362 err: vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongValue, "incorrect TIME value: '%s'", "foo"), 363 }, { 364 in: "select timestamp'foo'", 365 err: vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.WrongValue, "incorrect DATETIME value: '%s'", "foo"), 366 }} 367 for _, tc := range testcases { 368 t.Run(tc.in, func(t *testing.T) { 369 stmt, err := Parse(tc.in) 370 require.NoError(t, err) 371 known := GetBindvars(stmt) 372 bv := make(map[string]*querypb.BindVariable) 373 require.EqualError(t, Normalize(stmt, NewReservedVars("bv", known), bv), tc.err.Error()) 374 }) 375 } 376 } 377 378 func TestNormalizeValidSQL(t *testing.T) { 379 for _, tcase := range validSQL { 380 t.Run(tcase.input, func(t *testing.T) { 381 if tcase.partialDDL || tcase.ignoreNormalizerTest { 382 return 383 } 384 tree, err := Parse(tcase.input) 385 require.NoError(t, err, tcase.input) 386 // Skip the test for the queries that do not run the normalizer 387 if !CanNormalize(tree) { 388 return 389 } 390 bv := make(map[string]*querypb.BindVariable) 391 known := make(BindVars) 392 err = Normalize(tree, NewReservedVars("vtg", known), bv) 393 require.NoError(t, err) 394 normalizerOutput := String(tree) 395 if normalizerOutput == "otheradmin" || normalizerOutput == "otherread" { 396 return 397 } 398 _, err = Parse(normalizerOutput) 399 require.NoError(t, err, normalizerOutput) 400 }) 401 } 402 } 403 404 func TestGetBindVars(t *testing.T) { 405 stmt, err := Parse("select * from t where :v1 = :v2 and :v2 = :v3 and :v4 in ::v5") 406 if err != nil { 407 t.Fatal(err) 408 } 409 got := GetBindvars(stmt) 410 want := map[string]struct{}{ 411 "v1": {}, 412 "v2": {}, 413 "v3": {}, 414 "v4": {}, 415 "v5": {}, 416 } 417 if !reflect.DeepEqual(got, want) { 418 t.Errorf("GetBindVars: %v, want: %v", got, want) 419 } 420 } 421 422 /* 423 Skipping ColName, TableName: 424 BenchmarkNormalize-8 1000000 2205 ns/op 821 B/op 27 allocs/op 425 Prior to skip: 426 BenchmarkNormalize-8 500000 3620 ns/op 1461 B/op 55 allocs/op 427 */ 428 func BenchmarkNormalize(b *testing.B) { 429 sql := "select 'abcd', 20, 30.0, eid from a where 1=eid and name='3'" 430 ast, reservedVars, err := Parse2(sql) 431 if err != nil { 432 b.Fatal(err) 433 } 434 for i := 0; i < b.N; i++ { 435 require.NoError(b, Normalize(ast, NewReservedVars("", reservedVars), map[string]*querypb.BindVariable{})) 436 } 437 } 438 439 func BenchmarkNormalizeTraces(b *testing.B) { 440 for _, trace := range []string{"django_queries.txt", "lobsters.sql.gz"} { 441 b.Run(trace, func(b *testing.B) { 442 queries := loadQueries(b, trace) 443 if len(queries) > 10000 { 444 queries = queries[:10000] 445 } 446 447 parsed := make([]Statement, 0, len(queries)) 448 reservedVars := make([]BindVars, 0, len(queries)) 449 for _, q := range queries { 450 pp, kb, err := Parse2(q) 451 if err != nil { 452 b.Fatal(err) 453 } 454 parsed = append(parsed, pp) 455 reservedVars = append(reservedVars, kb) 456 } 457 458 b.ResetTimer() 459 b.ReportAllocs() 460 461 for i := 0; i < b.N; i++ { 462 for i, query := range parsed { 463 _ = Normalize(query, NewReservedVars("", reservedVars[i]), map[string]*querypb.BindVariable{}) 464 } 465 } 466 }) 467 } 468 } 469 470 func BenchmarkNormalizeVTGate(b *testing.B) { 471 const keyspace = "main_keyspace" 472 473 queries := loadQueries(b, "lobsters.sql.gz") 474 if len(queries) > 10000 { 475 queries = queries[:10000] 476 } 477 478 b.ResetTimer() 479 b.ReportAllocs() 480 481 for i := 0; i < b.N; i++ { 482 for _, sql := range queries { 483 stmt, reservedVars, err := Parse2(sql) 484 if err != nil { 485 b.Fatal(err) 486 } 487 488 query := sql 489 statement := stmt 490 bindVarNeeds := &BindVarNeeds{} 491 bindVars := make(map[string]*querypb.BindVariable) 492 _ = IgnoreMaxMaxMemoryRowsDirective(stmt) 493 494 // Normalize if possible and retry. 495 if CanNormalize(stmt) || MustRewriteAST(stmt, false) { 496 result, err := PrepareAST( 497 stmt, 498 NewReservedVars("vtg", reservedVars), 499 bindVars, 500 true, 501 keyspace, 502 SQLSelectLimitUnset, 503 "", 504 nil, /*sysvars*/ 505 nil, /*views*/ 506 ) 507 if err != nil { 508 b.Fatal(err) 509 } 510 statement = result.AST 511 bindVarNeeds = result.BindVarNeeds 512 query = String(statement) 513 } 514 515 _ = query 516 _ = statement 517 _ = bindVarNeeds 518 } 519 } 520 } 521 522 func randtmpl(template string) string { 523 const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 524 const numberBytes = "0123456789" 525 526 result := []byte(template) 527 for i, c := range result { 528 switch c { 529 case '#': 530 result[i] = numberBytes[rand.Intn(len(numberBytes))] 531 case '@': 532 result[i] = letterBytes[rand.Intn(len(letterBytes))] 533 } 534 } 535 return string(result) 536 } 537 538 func randString(n int) string { 539 const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 540 b := make([]byte, n) 541 for i := range b { 542 b[i] = letterBytes[rand.Intn(len(letterBytes))] 543 } 544 return string(b) 545 } 546 547 func BenchmarkNormalizeTPCCBinds(b *testing.B) { 548 query := `INSERT IGNORE INTO customer0 549 (c_id, c_d_id, c_w_id, c_first, c_middle, c_last, c_street_1, c_street_2, c_city, c_state, c_zip, c_phone, c_since, c_credit, c_credit_lim, c_discount, c_balance, c_ytd_payment, c_payment_cnt, c_delivery_cnt, c_data) 550 values 551 (:c_id, :c_d_id, :c_w_id, :c_first, :c_middle, :c_last, :c_street_1, :c_street_2, :c_city, :c_state, :c_zip, :c_phone, :c_since, :c_credit, :c_credit_lim, :c_discount, :c_balance, :c_ytd_payment, :c_payment_cnt, :c_delivery_cnt, :c_data)` 552 benchmarkNormalization(b, []string{query}) 553 } 554 555 func BenchmarkNormalizeTPCCInsert(b *testing.B) { 556 generateInsert := func(rows int) string { 557 var query bytes.Buffer 558 query.WriteString("INSERT IGNORE INTO customer0 (c_id, c_d_id, c_w_id, c_first, c_middle, c_last, c_street_1, c_street_2, c_city, c_state, c_zip, c_phone, c_since, c_credit, c_credit_lim, c_discount, c_balance, c_ytd_payment, c_payment_cnt, c_delivery_cnt, c_data) values ") 559 for i := 0; i < rows; i++ { 560 fmt.Fprintf(&query, "(%d, %d, %d, '%s','OE','%s','%s', '%s', '%s', '%s', '%s','%s',NOW(),'%s',50000,%f,-10,10,1,0,'%s' )", 561 rand.Int(), rand.Int(), rand.Int(), 562 "first-"+randString(rand.Intn(10)), 563 randtmpl("last-@@@@"), 564 randtmpl("street1-@@@@@@@@@@@@"), 565 randtmpl("street2-@@@@@@@@@@@@"), 566 randtmpl("city-@@@@@@@@@@@@"), 567 randtmpl("@@"), randtmpl("zip-#####"), 568 randtmpl("################"), 569 "GC", rand.Float64(), randString(300+rand.Intn(200)), 570 ) 571 if i < rows-1 { 572 query.WriteString(", ") 573 } 574 } 575 return query.String() 576 } 577 578 var queries []string 579 580 for i := 0; i < 1024; i++ { 581 queries = append(queries, generateInsert(4)) 582 } 583 584 benchmarkNormalization(b, queries) 585 } 586 587 func BenchmarkNormalizeTPCC(b *testing.B) { 588 templates := []string{ 589 `SELECT c_discount, c_last, c_credit, w_tax 590 FROM customer%d AS c 591 JOIN warehouse%d AS w ON c_w_id=w_id 592 WHERE w_id = %d 593 AND c_d_id = %d 594 AND c_id = %d`, 595 `SELECT d_next_o_id, d_tax 596 FROM district%d 597 WHERE d_w_id = %d 598 AND d_id = %d FOR UPDATE`, 599 `UPDATE district%d 600 SET d_next_o_id = %d 601 WHERE d_id = %d AND d_w_id= %d`, 602 `INSERT INTO orders%d 603 (o_id, o_d_id, o_w_id, o_c_id, o_entry_d, o_ol_cnt, o_all_local) 604 VALUES (%d,%d,%d,%d,NOW(),%d,%d)`, 605 `INSERT INTO new_orders%d (no_o_id, no_d_id, no_w_id) 606 VALUES (%d,%d,%d)`, 607 `SELECT i_price, i_name, i_data 608 FROM item%d 609 WHERE i_id = %d`, 610 `SELECT s_quantity, s_data, s_dist_%s s_dist 611 FROM stock%d 612 WHERE s_i_id = %d AND s_w_id= %d FOR UPDATE`, 613 `UPDATE stock%d 614 SET s_quantity = %d 615 WHERE s_i_id = %d 616 AND s_w_id= %d`, 617 `INSERT INTO order_line%d 618 (ol_o_id, ol_d_id, ol_w_id, ol_number, ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_dist_info) 619 VALUES (%d,%d,%d,%d,%d,%d,%d,%d,'%s')`, 620 `UPDATE warehouse%d 621 SET w_ytd = w_ytd + %d 622 WHERE w_id = %d`, 623 `SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name 624 FROM warehouse%d 625 WHERE w_id = %d`, 626 `UPDATE district%d 627 SET d_ytd = d_ytd + %d 628 WHERE d_w_id = %d 629 AND d_id= %d`, 630 `SELECT d_street_1, d_street_2, d_city, d_state, d_zip, d_name 631 FROM district%d 632 WHERE d_w_id = %d 633 AND d_id = %d`, 634 `SELECT count(c_id) namecnt 635 FROM customer%d 636 WHERE c_w_id = %d 637 AND c_d_id= %d 638 AND c_last='%s'`, 639 `SELECT c_first, c_middle, c_last, c_street_1, 640 c_street_2, c_city, c_state, c_zip, c_phone, 641 c_credit, c_credit_lim, c_discount, c_balance, c_ytd_payment, c_since 642 FROM customer%d 643 WHERE c_w_id = %d 644 AND c_d_id= %d 645 AND c_id=%d FOR UPDATE`, 646 `SELECT c_data 647 FROM customer%d 648 WHERE c_w_id = %d 649 AND c_d_id=%d 650 AND c_id= %d`, 651 `UPDATE customer%d 652 SET c_balance=%f, c_ytd_payment=%f, c_data='%s' 653 WHERE c_w_id = %d 654 AND c_d_id=%d 655 AND c_id=%d`, 656 `UPDATE customer%d 657 SET c_balance=%f, c_ytd_payment=%f 658 WHERE c_w_id = %d 659 AND c_d_id=%d 660 AND c_id=%d`, 661 `INSERT INTO history%d 662 (h_c_d_id, h_c_w_id, h_c_id, h_d_id, h_w_id, h_date, h_amount, h_data) 663 VALUES (%d,%d,%d,%d,%d,NOW(),%d,'%s')`, 664 `SELECT count(c_id) namecnt 665 FROM customer%d 666 WHERE c_w_id = %d 667 AND c_d_id= %d 668 AND c_last='%s'`, 669 `SELECT c_balance, c_first, c_middle, c_id 670 FROM customer%d 671 WHERE c_w_id = %d 672 AND c_d_id= %d 673 AND c_last='%s' ORDER BY c_first`, 674 `SELECT c_balance, c_first, c_middle, c_last 675 FROM customer%d 676 WHERE c_w_id = %d 677 AND c_d_id=%d 678 AND c_id=%d`, 679 `SELECT o_id, o_carrier_id, o_entry_d 680 FROM orders%d 681 WHERE o_w_id = %d 682 AND o_d_id = %d 683 AND o_c_id = %d 684 ORDER BY o_id DESC`, 685 `SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d 686 FROM order_line%d WHERE ol_w_id = %d AND ol_d_id = %d AND ol_o_id = %d`, 687 `SELECT no_o_id 688 FROM new_orders%d 689 WHERE no_d_id = %d 690 AND no_w_id = %d 691 ORDER BY no_o_id ASC LIMIT 1 FOR UPDATE`, 692 `DELETE FROM new_orders%d 693 WHERE no_o_id = %d 694 AND no_d_id = %d 695 AND no_w_id = %d`, 696 `SELECT o_c_id 697 FROM orders%d 698 WHERE o_id = %d 699 AND o_d_id = %d 700 AND o_w_id = %d`, 701 `UPDATE orders%d 702 SET o_carrier_id = %d 703 WHERE o_id = %d 704 AND o_d_id = %d 705 AND o_w_id = %d`, 706 `UPDATE order_line%d 707 SET ol_delivery_d = NOW() 708 WHERE ol_o_id = %d 709 AND ol_d_id = %d 710 AND ol_w_id = %d`, 711 `SELECT SUM(ol_amount) sm 712 FROM order_line%d 713 WHERE ol_o_id = %d 714 AND ol_d_id = %d 715 AND ol_w_id = %d`, 716 `UPDATE customer%d 717 SET c_balance = c_balance + %f, 718 c_delivery_cnt = c_delivery_cnt + 1 719 WHERE c_id = %d 720 AND c_d_id = %d 721 AND c_w_id = %d`, 722 `SELECT d_next_o_id 723 FROM district%d 724 WHERE d_id = %d AND d_w_id= %d`, 725 `SELECT COUNT(DISTINCT(s.s_i_id)) 726 FROM stock%d AS s 727 JOIN order_line%d AS ol ON ol.ol_w_id=s.s_w_id AND ol.ol_i_id=s.s_i_id 728 WHERE ol.ol_w_id = %d 729 AND ol.ol_d_id = %d 730 AND ol.ol_o_id < %d 731 AND ol.ol_o_id >= %d 732 AND s.s_w_id= %d 733 AND s.s_quantity < %d `, 734 `SELECT DISTINCT ol_i_id FROM order_line%d 735 WHERE ol_w_id = %d AND ol_d_id = %d 736 AND ol_o_id < %d AND ol_o_id >= %d`, 737 `SELECT count(*) FROM stock%d 738 WHERE s_w_id = %d AND s_i_id = %d 739 AND s_quantity < %d`, 740 `SELECT min(no_o_id) mo 741 FROM new_orders%d 742 WHERE no_w_id = %d AND no_d_id = %d`, 743 `SELECT o_id FROM orders%d o, (SELECT o_c_id,o_w_id,o_d_id,count(distinct o_id) FROM orders%d WHERE o_w_id=%d AND o_d_id=%d AND o_id > 2100 AND o_id < %d GROUP BY o_c_id,o_d_id,o_w_id having count( distinct o_id) > 1 limit 1) t WHERE t.o_w_id=o.o_w_id and t.o_d_id=o.o_d_id and t.o_c_id=o.o_c_id limit 1 `, 744 `DELETE FROM order_line%d where ol_w_id=%d AND ol_d_id=%d AND ol_o_id=%d`, 745 `DELETE FROM orders%d where o_w_id=%d AND o_d_id=%d and o_id=%d`, 746 `DELETE FROM history%d where h_w_id=%d AND h_d_id=%d LIMIT 10`, 747 } 748 749 re := regexp.MustCompile(`%\w`) 750 repl := func(m string) string { 751 switch m { 752 case "%s": 753 return "RANDOM_STRING" 754 case "%d": 755 return strconv.Itoa(rand.Int()) 756 case "%f": 757 return strconv.FormatFloat(rand.Float64(), 'f', 8, 64) 758 default: 759 panic(m) 760 } 761 } 762 763 var queries []string 764 765 for _, tmpl := range templates { 766 for i := 0; i < 128; i++ { 767 queries = append(queries, re.ReplaceAllStringFunc(tmpl, repl)) 768 } 769 } 770 771 benchmarkNormalization(b, queries) 772 } 773 774 func benchmarkNormalization(b *testing.B, sqls []string) { 775 b.Helper() 776 b.ReportAllocs() 777 b.ResetTimer() 778 for i := 0; i < b.N; i++ { 779 for _, sql := range sqls { 780 stmt, reserved, err := Parse2(sql) 781 if err != nil { 782 b.Fatalf("%v: %q", err, sql) 783 } 784 785 reservedVars := NewReservedVars("vtg", reserved) 786 _, err = PrepareAST( 787 stmt, 788 reservedVars, 789 make(map[string]*querypb.BindVariable), 790 true, 791 "keyspace0", 792 SQLSelectLimitUnset, 793 "", 794 nil, 795 nil, 796 ) 797 if err != nil { 798 b.Fatal(err) 799 } 800 } 801 } 802 }