vitess.io/vitess@v0.16.2/go/vt/sqlparser/ast_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "bytes" 21 "encoding/json" 22 "reflect" 23 "strings" 24 "testing" 25 "unsafe" 26 27 "github.com/stretchr/testify/assert" 28 29 "github.com/stretchr/testify/require" 30 ) 31 32 func TestAppend(t *testing.T) { 33 query := "select * from t where a = 1" 34 tree, err := Parse(query) 35 require.NoError(t, err) 36 var b strings.Builder 37 Append(&b, tree) 38 got := b.String() 39 want := query 40 if got != want { 41 t.Errorf("Append: %s, want %s", got, want) 42 } 43 Append(&b, tree) 44 got = b.String() 45 want = query + query 46 if got != want { 47 t.Errorf("Append: %s, want %s", got, want) 48 } 49 } 50 51 func TestSelect(t *testing.T) { 52 tree, err := Parse("select * from t where a = 1") 53 require.NoError(t, err) 54 expr := tree.(*Select).Where.Expr 55 56 sel := &Select{} 57 sel.AddWhere(expr) 58 buf := NewTrackedBuffer(nil) 59 sel.Where.Format(buf) 60 assert.Equal(t, " where a = 1", buf.String()) 61 sel.AddWhere(expr) 62 buf = NewTrackedBuffer(nil) 63 sel.Where.Format(buf) 64 assert.Equal(t, " where a = 1", buf.String()) 65 66 sel = &Select{} 67 sel.AddHaving(expr) 68 buf = NewTrackedBuffer(nil) 69 sel.Having.Format(buf) 70 assert.Equal(t, " having a = 1", buf.String()) 71 72 sel.AddHaving(expr) 73 buf = NewTrackedBuffer(nil) 74 sel.Having.Format(buf) 75 assert.Equal(t, " having a = 1", buf.String()) 76 77 tree, err = Parse("select * from t where a = 1 or b = 1") 78 require.NoError(t, err) 79 expr = tree.(*Select).Where.Expr 80 sel = &Select{} 81 sel.AddWhere(expr) 82 buf = NewTrackedBuffer(nil) 83 sel.Where.Format(buf) 84 assert.Equal(t, " where a = 1 or b = 1", buf.String()) 85 86 sel = &Select{} 87 sel.AddHaving(expr) 88 buf = NewTrackedBuffer(nil) 89 sel.Having.Format(buf) 90 assert.Equal(t, " having a = 1 or b = 1", buf.String()) 91 92 } 93 94 func TestUpdate(t *testing.T) { 95 tree, err := Parse("update t set a = 1") 96 require.NoError(t, err) 97 98 upd, ok := tree.(*Update) 99 require.True(t, ok) 100 101 upd.AddWhere(&ComparisonExpr{ 102 Left: &ColName{Name: NewIdentifierCI("b")}, 103 Operator: EqualOp, 104 Right: NewIntLiteral("2"), 105 }) 106 assert.Equal(t, "update t set a = 1 where b = 2", String(upd)) 107 108 upd.AddWhere(&ComparisonExpr{ 109 Left: &ColName{Name: NewIdentifierCI("c")}, 110 Operator: EqualOp, 111 Right: NewIntLiteral("3"), 112 }) 113 assert.Equal(t, "update t set a = 1 where b = 2 and c = 3", String(upd)) 114 } 115 116 func TestRemoveHints(t *testing.T) { 117 for _, query := range []string{ 118 "select * from t use index (i)", 119 "select * from t force index (i)", 120 } { 121 tree, err := Parse(query) 122 if err != nil { 123 t.Fatal(err) 124 } 125 sel := tree.(*Select) 126 sel.From = TableExprs{ 127 sel.From[0].(*AliasedTableExpr).RemoveHints(), 128 } 129 buf := NewTrackedBuffer(nil) 130 sel.Format(buf) 131 if got, want := buf.String(), "select * from t"; got != want { 132 t.Errorf("stripped query: %s, want %s", got, want) 133 } 134 } 135 } 136 137 func TestAddOrder(t *testing.T) { 138 src, err := Parse("select foo, bar from baz order by foo") 139 require.NoError(t, err) 140 order := src.(*Select).OrderBy[0] 141 dst, err := Parse("select * from t") 142 require.NoError(t, err) 143 dst.(*Select).AddOrder(order) 144 buf := NewTrackedBuffer(nil) 145 dst.Format(buf) 146 require.Equal(t, "select * from t order by foo asc", buf.String()) 147 dst, err = Parse("select * from t union select * from s") 148 require.NoError(t, err) 149 dst.(*Union).AddOrder(order) 150 buf = NewTrackedBuffer(nil) 151 dst.Format(buf) 152 require.Equal(t, "select * from t union select * from s order by foo asc", buf.String()) 153 } 154 155 func TestSetLimit(t *testing.T) { 156 src, err := Parse("select foo, bar from baz limit 4") 157 require.NoError(t, err) 158 limit := src.(*Select).Limit 159 dst, err := Parse("select * from t") 160 require.NoError(t, err) 161 dst.(*Select).SetLimit(limit) 162 buf := NewTrackedBuffer(nil) 163 dst.Format(buf) 164 require.Equal(t, "select * from t limit 4", buf.String()) 165 dst, err = Parse("select * from t union select * from s") 166 require.NoError(t, err) 167 dst.(*Union).SetLimit(limit) 168 buf = NewTrackedBuffer(nil) 169 dst.Format(buf) 170 require.Equal(t, "select * from t union select * from s limit 4", buf.String()) 171 } 172 173 func TestDDL(t *testing.T) { 174 testcases := []struct { 175 query string 176 output DDLStatement 177 affected []string 178 }{{ 179 query: "create table a", 180 output: &CreateTable{ 181 Table: TableName{Name: NewIdentifierCS("a")}, 182 }, 183 affected: []string{"a"}, 184 }, { 185 query: "rename table a to b", 186 output: &RenameTable{ 187 TablePairs: []*RenameTablePair{ 188 { 189 FromTable: TableName{Name: NewIdentifierCS("a")}, 190 ToTable: TableName{Name: NewIdentifierCS("b")}, 191 }, 192 }, 193 }, 194 affected: []string{"a", "b"}, 195 }, { 196 query: "rename table a to b, c to d", 197 output: &RenameTable{ 198 TablePairs: []*RenameTablePair{ 199 { 200 FromTable: TableName{Name: NewIdentifierCS("a")}, 201 ToTable: TableName{Name: NewIdentifierCS("b")}, 202 }, { 203 FromTable: TableName{Name: NewIdentifierCS("c")}, 204 ToTable: TableName{Name: NewIdentifierCS("d")}, 205 }, 206 }, 207 }, 208 affected: []string{"a", "b", "c", "d"}, 209 }, { 210 query: "drop table a", 211 output: &DropTable{ 212 FromTables: TableNames{ 213 TableName{Name: NewIdentifierCS("a")}, 214 }, 215 }, 216 affected: []string{"a"}, 217 }, { 218 query: "drop table a, b", 219 output: &DropTable{ 220 FromTables: TableNames{ 221 TableName{Name: NewIdentifierCS("a")}, 222 TableName{Name: NewIdentifierCS("b")}, 223 }, 224 }, 225 affected: []string{"a", "b"}, 226 }} 227 for _, tcase := range testcases { 228 got, err := Parse(tcase.query) 229 if err != nil { 230 t.Fatal(err) 231 } 232 if !reflect.DeepEqual(got, tcase.output) { 233 t.Errorf("%s: %v, want %v", tcase.query, got, tcase.output) 234 } 235 want := make(TableNames, 0, len(tcase.affected)) 236 for _, t := range tcase.affected { 237 want = append(want, TableName{Name: NewIdentifierCS(t)}) 238 } 239 if affected := got.(DDLStatement).AffectedTables(); !reflect.DeepEqual(affected, want) { 240 t.Errorf("Affected(%s): %v, want %v", tcase.query, affected, want) 241 } 242 } 243 } 244 245 func TestSetAutocommitON(t *testing.T) { 246 stmt, err := Parse("SET autocommit=ON") 247 require.NoError(t, err) 248 s, ok := stmt.(*Set) 249 if !ok { 250 t.Errorf("SET statement is not Set: %T", s) 251 } 252 253 if len(s.Exprs) < 1 { 254 t.Errorf("SET statement has no expressions") 255 } 256 257 e := s.Exprs[0] 258 switch v := e.Expr.(type) { 259 case *Literal: 260 if v.Type != StrVal { 261 t.Errorf("SET statement value is not StrVal: %T", v) 262 } 263 264 if "on" != v.Val { 265 t.Errorf("SET statement value want: on, got: %s", v.Val) 266 } 267 default: 268 t.Errorf("SET statement expression is not Literal: %T", e.Expr) 269 } 270 271 stmt, err = Parse("SET @@session.autocommit=ON") 272 require.NoError(t, err) 273 s, ok = stmt.(*Set) 274 if !ok { 275 t.Errorf("SET statement is not Set: %T", s) 276 } 277 278 if len(s.Exprs) < 1 { 279 t.Errorf("SET statement has no expressions") 280 } 281 282 e = s.Exprs[0] 283 switch v := e.Expr.(type) { 284 case *Literal: 285 if v.Type != StrVal { 286 t.Errorf("SET statement value is not StrVal: %T", v) 287 } 288 289 if "on" != v.Val { 290 t.Errorf("SET statement value want: on, got: %s", v.Val) 291 } 292 default: 293 t.Errorf("SET statement expression is not Literal: %T", e.Expr) 294 } 295 } 296 297 func TestSetAutocommitOFF(t *testing.T) { 298 stmt, err := Parse("SET autocommit=OFF") 299 require.NoError(t, err) 300 s, ok := stmt.(*Set) 301 if !ok { 302 t.Errorf("SET statement is not Set: %T", s) 303 } 304 305 if len(s.Exprs) < 1 { 306 t.Errorf("SET statement has no expressions") 307 } 308 309 e := s.Exprs[0] 310 switch v := e.Expr.(type) { 311 case *Literal: 312 if v.Type != StrVal { 313 t.Errorf("SET statement value is not StrVal: %T", v) 314 } 315 316 if "off" != v.Val { 317 t.Errorf("SET statement value want: on, got: %s", v.Val) 318 } 319 default: 320 t.Errorf("SET statement expression is not Literal: %T", e.Expr) 321 } 322 323 stmt, err = Parse("SET @@session.autocommit=OFF") 324 require.NoError(t, err) 325 s, ok = stmt.(*Set) 326 if !ok { 327 t.Errorf("SET statement is not Set: %T", s) 328 } 329 330 if len(s.Exprs) < 1 { 331 t.Errorf("SET statement has no expressions") 332 } 333 334 e = s.Exprs[0] 335 switch v := e.Expr.(type) { 336 case *Literal: 337 if v.Type != StrVal { 338 t.Errorf("SET statement value is not StrVal: %T", v) 339 } 340 341 if "off" != v.Val { 342 t.Errorf("SET statement value want: on, got: %s", v.Val) 343 } 344 default: 345 t.Errorf("SET statement expression is not Literal: %T", e.Expr) 346 } 347 348 } 349 350 func TestWhere(t *testing.T) { 351 var w *Where 352 buf := NewTrackedBuffer(nil) 353 w.Format(buf) 354 if buf.String() != "" { 355 t.Errorf("w.Format(nil): %q, want \"\"", buf.String()) 356 } 357 w = NewWhere(WhereClause, nil) 358 buf = NewTrackedBuffer(nil) 359 w.Format(buf) 360 if buf.String() != "" { 361 t.Errorf("w.Format(&Where{nil}: %q, want \"\"", buf.String()) 362 } 363 } 364 365 func TestIsAggregate(t *testing.T) { 366 f := FuncExpr{Name: NewIdentifierCI("avg")} 367 if !f.IsAggregate() { 368 t.Error("IsAggregate: false, want true") 369 } 370 371 f = FuncExpr{Name: NewIdentifierCI("Avg")} 372 if !f.IsAggregate() { 373 t.Error("IsAggregate: false, want true") 374 } 375 376 f = FuncExpr{Name: NewIdentifierCI("foo")} 377 if f.IsAggregate() { 378 t.Error("IsAggregate: true, want false") 379 } 380 } 381 382 func TestIsImpossible(t *testing.T) { 383 f := ComparisonExpr{ 384 Operator: NotEqualOp, 385 Left: NewIntLiteral("1"), 386 Right: NewIntLiteral("1"), 387 } 388 if !f.IsImpossible() { 389 t.Error("IsImpossible: false, want true") 390 } 391 392 f = ComparisonExpr{ 393 Operator: EqualOp, 394 Left: NewIntLiteral("1"), 395 Right: NewIntLiteral("1"), 396 } 397 if f.IsImpossible() { 398 t.Error("IsImpossible: true, want false") 399 } 400 401 f = ComparisonExpr{ 402 Operator: NotEqualOp, 403 Left: NewIntLiteral("1"), 404 Right: NewIntLiteral("2"), 405 } 406 if f.IsImpossible() { 407 t.Error("IsImpossible: true, want false") 408 } 409 } 410 411 func TestReplaceExpr(t *testing.T) { 412 tcases := []struct { 413 in, out string 414 }{{ 415 in: "select * from t where (select a from b)", 416 out: ":a", 417 }, { 418 in: "select * from t where (select a from b) and b", 419 out: ":a and b", 420 }, { 421 in: "select * from t where a and (select a from b)", 422 out: "a and :a", 423 }, { 424 in: "select * from t where (select a from b) or b", 425 out: ":a or b", 426 }, { 427 in: "select * from t where a or (select a from b)", 428 out: "a or :a", 429 }, { 430 in: "select * from t where not (select a from b)", 431 out: "not :a", 432 }, { 433 in: "select * from t where ((select a from b))", 434 out: ":a", 435 }, { 436 in: "select * from t where (select a from b) = 1", 437 out: ":a = 1", 438 }, { 439 in: "select * from t where a = (select a from b)", 440 out: "a = :a", 441 }, { 442 in: "select * from t where a like b escape (select a from b)", 443 out: "a like b escape :a", 444 }, { 445 in: "select * from t where (select a from b) between a and b", 446 out: ":a between a and b", 447 }, { 448 in: "select * from t where a between (select a from b) and b", 449 out: "a between :a and b", 450 }, { 451 in: "select * from t where a between b and (select a from b)", 452 out: "a between b and :a", 453 }, { 454 in: "select * from t where (select a from b) is null", 455 out: ":a is null", 456 }, { 457 // exists should not replace. 458 in: "select * from t where exists (select a from b)", 459 out: "exists (select a from b)", 460 }, { 461 in: "select * from t where a in ((select a from b), 1)", 462 out: "a in (:a, 1)", 463 }, { 464 in: "select * from t where a in (0, (select a from b), 1)", 465 out: "a in (0, :a, 1)", 466 }, { 467 in: "select * from t where (select a from b) + 1", 468 out: ":a + 1", 469 }, { 470 in: "select * from t where 1+(select a from b)", 471 out: "1 + :a", 472 }, { 473 in: "select * from t where -(select a from b)", 474 out: "-:a", 475 }, { 476 in: "select * from t where interval (select a from b) aa", 477 out: "interval :a aa", 478 }, { 479 in: "select * from t where (select a from b) collate utf8", 480 out: ":a collate utf8", 481 }, { 482 in: "select * from t where func((select a from b), 1)", 483 out: "func(:a, 1)", 484 }, { 485 in: "select * from t where func(1, (select a from b), 1)", 486 out: "func(1, :a, 1)", 487 }, { 488 in: "select * from t where group_concat((select a from b), 1 order by a)", 489 out: "group_concat(:a, 1 order by a asc)", 490 }, { 491 in: "select * from t where group_concat(1 order by (select a from b), a)", 492 out: "group_concat(1 order by :a asc, a asc)", 493 }, { 494 in: "select * from t where group_concat(1 order by a, (select a from b))", 495 out: "group_concat(1 order by a asc, :a asc)", 496 }, { 497 in: "select * from t where substr(a, (select a from b), b)", 498 out: "substr(a, :a, b)", 499 }, { 500 in: "select * from t where substr(a, b, (select a from b))", 501 out: "substr(a, b, :a)", 502 }, { 503 in: "select * from t where convert((select a from b), json)", 504 out: "convert(:a, json)", 505 }, { 506 in: "select * from t where convert((select a from b) using utf8)", 507 out: "convert(:a using utf8)", 508 }, { 509 in: "select * from t where case (select a from b) when a then b when b then c else d end", 510 out: "case :a when a then b when b then c else d end", 511 }, { 512 in: "select * from t where case a when (select a from b) then b when b then c else d end", 513 out: "case a when :a then b when b then c else d end", 514 }, { 515 in: "select * from t where case a when b then (select a from b) when b then c else d end", 516 out: "case a when b then :a when b then c else d end", 517 }, { 518 in: "select * from t where case a when b then c when (select a from b) then c else d end", 519 out: "case a when b then c when :a then c else d end", 520 }, { 521 in: "select * from t where case a when b then c when d then c else (select a from b) end", 522 out: "case a when b then c when d then c else :a end", 523 }} 524 to := NewArgument("a") 525 for _, tcase := range tcases { 526 tree, err := Parse(tcase.in) 527 if err != nil { 528 t.Fatal(err) 529 } 530 var from *Subquery 531 _ = Walk(func(node SQLNode) (kontinue bool, err error) { 532 if sq, ok := node.(*Subquery); ok { 533 from = sq 534 return false, nil 535 } 536 return true, nil 537 }, tree) 538 if from == nil { 539 t.Fatalf("from is nil for %s", tcase.in) 540 } 541 expr := ReplaceExpr(tree.(*Select).Where.Expr, from, to) 542 got := String(expr) 543 if tcase.out != got { 544 t.Errorf("ReplaceExpr(%s): %s, want %s", tcase.in, got, tcase.out) 545 } 546 } 547 } 548 549 func TestColNameEqual(t *testing.T) { 550 var c1, c2 *ColName 551 if c1.Equal(c2) { 552 t.Error("nil columns equal, want unequal") 553 } 554 c1 = &ColName{ 555 Name: NewIdentifierCI("aa"), 556 } 557 c2 = &ColName{ 558 Name: NewIdentifierCI("bb"), 559 } 560 if c1.Equal(c2) { 561 t.Error("columns equal, want unequal") 562 } 563 c2.Name = NewIdentifierCI("aa") 564 if !c1.Equal(c2) { 565 t.Error("columns unequal, want equal") 566 } 567 } 568 569 func TestIdentifierCI(t *testing.T) { 570 str := NewIdentifierCI("Ab") 571 if str.String() != "Ab" { 572 t.Errorf("String=%s, want Ab", str.String()) 573 } 574 if str.String() != "Ab" { 575 t.Errorf("Val=%s, want Ab", str.String()) 576 } 577 if str.Lowered() != "ab" { 578 t.Errorf("Val=%s, want ab", str.Lowered()) 579 } 580 if !str.Equal(NewIdentifierCI("aB")) { 581 t.Error("str.Equal(NewIdentifierCI(aB))=false, want true") 582 } 583 if !str.EqualString("ab") { 584 t.Error("str.EqualString(ab)=false, want true") 585 } 586 str = NewIdentifierCI("") 587 if str.Lowered() != "" { 588 t.Errorf("Val=%s, want \"\"", str.Lowered()) 589 } 590 } 591 592 func TestIdentifierCIMarshal(t *testing.T) { 593 str := NewIdentifierCI("Ab") 594 b, err := json.Marshal(str) 595 if err != nil { 596 t.Fatal(err) 597 } 598 got := string(b) 599 want := `"Ab"` 600 if got != want { 601 t.Errorf("json.Marshal()= %s, want %s", got, want) 602 } 603 var out IdentifierCI 604 if err := json.Unmarshal(b, &out); err != nil { 605 t.Errorf("Unmarshal err: %v, want nil", err) 606 } 607 if !reflect.DeepEqual(out, str) { 608 t.Errorf("Unmarshal: %v, want %v", out, str) 609 } 610 } 611 612 func TestIdentifierCISize(t *testing.T) { 613 size := unsafe.Sizeof(NewIdentifierCI("")) 614 want := 2 * unsafe.Sizeof("") 615 assert.Equal(t, want, size, "size of IdentifierCI") 616 } 617 618 func TestIdentifierCSMarshal(t *testing.T) { 619 str := NewIdentifierCS("Ab") 620 b, err := json.Marshal(str) 621 if err != nil { 622 t.Fatal(err) 623 } 624 got := string(b) 625 want := `"Ab"` 626 if got != want { 627 t.Errorf("json.Marshal()= %s, want %s", got, want) 628 } 629 var out IdentifierCS 630 if err := json.Unmarshal(b, &out); err != nil { 631 t.Errorf("Unmarshal err: %v, want nil", err) 632 } 633 if !reflect.DeepEqual(out, str) { 634 t.Errorf("Unmarshal: %v, want %v", out, str) 635 } 636 } 637 638 func TestHexDecode(t *testing.T) { 639 testcase := []struct { 640 in, out string 641 }{{ 642 in: "313233", 643 out: "123", 644 }, { 645 in: "ag", 646 out: "encoding/hex: invalid byte: U+0067 'g'", 647 }, { 648 in: "777", 649 out: "encoding/hex: odd length hex string", 650 }} 651 for _, tc := range testcase { 652 out, err := NewHexLiteral(tc.in).HexDecode() 653 if err != nil { 654 if err.Error() != tc.out { 655 t.Errorf("Decode(%q): %v, want %s", tc.in, err, tc.out) 656 } 657 continue 658 } 659 if !bytes.Equal(out, []byte(tc.out)) { 660 t.Errorf("Decode(%q): %s, want %s", tc.in, out, tc.out) 661 } 662 } 663 } 664 665 func TestCompliantName(t *testing.T) { 666 testcases := []struct { 667 in, out string 668 }{{ 669 in: "aa", 670 out: "aa", 671 }, { 672 in: "1a", 673 out: "_a", 674 }, { 675 in: "a1", 676 out: "a1", 677 }, { 678 in: "a.b", 679 out: "a_b", 680 }, { 681 in: ".ab", 682 out: "_ab", 683 }} 684 for _, tc := range testcases { 685 out := NewIdentifierCI(tc.in).CompliantName() 686 if out != tc.out { 687 t.Errorf("IdentifierCI(%s).CompliantNamt: %s, want %s", tc.in, out, tc.out) 688 } 689 out = NewIdentifierCS(tc.in).CompliantName() 690 if out != tc.out { 691 t.Errorf("IdentifierCS(%s).CompliantNamt: %s, want %s", tc.in, out, tc.out) 692 } 693 } 694 } 695 696 func TestColumns_FindColumn(t *testing.T) { 697 cols := Columns{NewIdentifierCI("a"), NewIdentifierCI("c"), NewIdentifierCI("b"), NewIdentifierCI("0")} 698 699 testcases := []struct { 700 in string 701 out int 702 }{{ 703 in: "a", 704 out: 0, 705 }, { 706 in: "b", 707 out: 2, 708 }, 709 { 710 in: "0", 711 out: 3, 712 }, 713 { 714 in: "f", 715 out: -1, 716 }} 717 718 for _, tc := range testcases { 719 val := cols.FindColumn(NewIdentifierCI(tc.in)) 720 if val != tc.out { 721 t.Errorf("FindColumn(%s): %d, want %d", tc.in, val, tc.out) 722 } 723 } 724 } 725 726 func TestSplitStatementToPieces(t *testing.T) { 727 testcases := []struct { 728 input string 729 output string 730 }{{ 731 input: "select * from table1; \t; \n; \n\t\t ;select * from table1;", 732 output: "select * from table1;select * from table1", 733 }, { 734 input: "select * from table", 735 }, { 736 input: "select * from table;", 737 output: "select * from table", 738 }, { 739 input: "select * from table; ", 740 output: "select * from table", 741 }, { 742 input: "select * from table1; select * from table2;", 743 output: "select * from table1; select * from table2", 744 }, { 745 input: "select * from /* comment ; */ table;", 746 output: "select * from /* comment ; */ table", 747 }, { 748 input: "select * from table where semi = ';';", 749 output: "select * from table where semi = ';'", 750 }, { 751 input: "select * from table1;--comment;\nselect * from table2;", 752 output: "select * from table1;--comment;\nselect * from table2", 753 }, { 754 input: "CREATE TABLE `total_data` (`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'id', " + 755 "`region` varchar(32) NOT NULL COMMENT 'region name, like zh; th; kepler'," + 756 "`data_size` bigint NOT NULL DEFAULT '0' COMMENT 'data size;'," + 757 "`createtime` datetime NOT NULL DEFAULT NOW() COMMENT 'create time;'," + 758 "`comment` varchar(100) NOT NULL DEFAULT '' COMMENT 'comment'," + 759 "PRIMARY KEY (`id`))", 760 }} 761 762 for _, tcase := range testcases { 763 t.Run(tcase.input, func(t *testing.T) { 764 if tcase.output == "" { 765 tcase.output = tcase.input 766 } 767 768 stmtPieces, err := SplitStatementToPieces(tcase.input) 769 require.NoError(t, err) 770 771 out := strings.Join(stmtPieces, ";") 772 require.Equal(t, tcase.output, out) 773 }) 774 } 775 } 776 777 func TestTypeConversion(t *testing.T) { 778 ct1 := &ColumnType{Type: "BIGINT"} 779 ct2 := &ColumnType{Type: "bigint"} 780 assert.Equal(t, ct1.SQLType(), ct2.SQLType()) 781 } 782 783 func TestDefaultStatus(t *testing.T) { 784 assert.Equal(t, 785 String(&Default{ColName: "status"}), 786 "default(`status`)") 787 } 788 789 func TestShowTableStatus(t *testing.T) { 790 query := "Show Table Status FROM customer" 791 tree, err := Parse(query) 792 require.NoError(t, err) 793 require.NotNil(t, tree) 794 } 795 796 func BenchmarkStringTraces(b *testing.B) { 797 for _, trace := range []string{"django_queries.txt", "lobsters.sql.gz"} { 798 b.Run(trace, func(b *testing.B) { 799 queries := loadQueries(b, trace) 800 if len(queries) > 10000 { 801 queries = queries[:10000] 802 } 803 804 parsed := make([]Statement, 0, len(queries)) 805 for _, q := range queries { 806 pp, err := Parse(q) 807 if err != nil { 808 b.Fatal(err) 809 } 810 parsed = append(parsed, pp) 811 } 812 813 b.ResetTimer() 814 b.ReportAllocs() 815 816 for i := 0; i < b.N; i++ { 817 for _, stmt := range parsed { 818 _ = String(stmt) 819 } 820 } 821 }) 822 } 823 }