vitess.io/vitess@v0.16.2/go/vt/sqlparser/ast_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/json"
    22  	"reflect"
    23  	"strings"
    24  	"testing"
    25  	"unsafe"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  
    29  	"github.com/stretchr/testify/require"
    30  )
    31  
    32  func TestAppend(t *testing.T) {
    33  	query := "select * from t where a = 1"
    34  	tree, err := Parse(query)
    35  	require.NoError(t, err)
    36  	var b strings.Builder
    37  	Append(&b, tree)
    38  	got := b.String()
    39  	want := query
    40  	if got != want {
    41  		t.Errorf("Append: %s, want %s", got, want)
    42  	}
    43  	Append(&b, tree)
    44  	got = b.String()
    45  	want = query + query
    46  	if got != want {
    47  		t.Errorf("Append: %s, want %s", got, want)
    48  	}
    49  }
    50  
    51  func TestSelect(t *testing.T) {
    52  	tree, err := Parse("select * from t where a = 1")
    53  	require.NoError(t, err)
    54  	expr := tree.(*Select).Where.Expr
    55  
    56  	sel := &Select{}
    57  	sel.AddWhere(expr)
    58  	buf := NewTrackedBuffer(nil)
    59  	sel.Where.Format(buf)
    60  	assert.Equal(t, " where a = 1", buf.String())
    61  	sel.AddWhere(expr)
    62  	buf = NewTrackedBuffer(nil)
    63  	sel.Where.Format(buf)
    64  	assert.Equal(t, " where a = 1", buf.String())
    65  
    66  	sel = &Select{}
    67  	sel.AddHaving(expr)
    68  	buf = NewTrackedBuffer(nil)
    69  	sel.Having.Format(buf)
    70  	assert.Equal(t, " having a = 1", buf.String())
    71  
    72  	sel.AddHaving(expr)
    73  	buf = NewTrackedBuffer(nil)
    74  	sel.Having.Format(buf)
    75  	assert.Equal(t, " having a = 1", buf.String())
    76  
    77  	tree, err = Parse("select * from t where a = 1 or b = 1")
    78  	require.NoError(t, err)
    79  	expr = tree.(*Select).Where.Expr
    80  	sel = &Select{}
    81  	sel.AddWhere(expr)
    82  	buf = NewTrackedBuffer(nil)
    83  	sel.Where.Format(buf)
    84  	assert.Equal(t, " where a = 1 or b = 1", buf.String())
    85  
    86  	sel = &Select{}
    87  	sel.AddHaving(expr)
    88  	buf = NewTrackedBuffer(nil)
    89  	sel.Having.Format(buf)
    90  	assert.Equal(t, " having a = 1 or b = 1", buf.String())
    91  
    92  }
    93  
    94  func TestUpdate(t *testing.T) {
    95  	tree, err := Parse("update t set a = 1")
    96  	require.NoError(t, err)
    97  
    98  	upd, ok := tree.(*Update)
    99  	require.True(t, ok)
   100  
   101  	upd.AddWhere(&ComparisonExpr{
   102  		Left:     &ColName{Name: NewIdentifierCI("b")},
   103  		Operator: EqualOp,
   104  		Right:    NewIntLiteral("2"),
   105  	})
   106  	assert.Equal(t, "update t set a = 1 where b = 2", String(upd))
   107  
   108  	upd.AddWhere(&ComparisonExpr{
   109  		Left:     &ColName{Name: NewIdentifierCI("c")},
   110  		Operator: EqualOp,
   111  		Right:    NewIntLiteral("3"),
   112  	})
   113  	assert.Equal(t, "update t set a = 1 where b = 2 and c = 3", String(upd))
   114  }
   115  
   116  func TestRemoveHints(t *testing.T) {
   117  	for _, query := range []string{
   118  		"select * from t use index (i)",
   119  		"select * from t force index (i)",
   120  	} {
   121  		tree, err := Parse(query)
   122  		if err != nil {
   123  			t.Fatal(err)
   124  		}
   125  		sel := tree.(*Select)
   126  		sel.From = TableExprs{
   127  			sel.From[0].(*AliasedTableExpr).RemoveHints(),
   128  		}
   129  		buf := NewTrackedBuffer(nil)
   130  		sel.Format(buf)
   131  		if got, want := buf.String(), "select * from t"; got != want {
   132  			t.Errorf("stripped query: %s, want %s", got, want)
   133  		}
   134  	}
   135  }
   136  
   137  func TestAddOrder(t *testing.T) {
   138  	src, err := Parse("select foo, bar from baz order by foo")
   139  	require.NoError(t, err)
   140  	order := src.(*Select).OrderBy[0]
   141  	dst, err := Parse("select * from t")
   142  	require.NoError(t, err)
   143  	dst.(*Select).AddOrder(order)
   144  	buf := NewTrackedBuffer(nil)
   145  	dst.Format(buf)
   146  	require.Equal(t, "select * from t order by foo asc", buf.String())
   147  	dst, err = Parse("select * from t union select * from s")
   148  	require.NoError(t, err)
   149  	dst.(*Union).AddOrder(order)
   150  	buf = NewTrackedBuffer(nil)
   151  	dst.Format(buf)
   152  	require.Equal(t, "select * from t union select * from s order by foo asc", buf.String())
   153  }
   154  
   155  func TestSetLimit(t *testing.T) {
   156  	src, err := Parse("select foo, bar from baz limit 4")
   157  	require.NoError(t, err)
   158  	limit := src.(*Select).Limit
   159  	dst, err := Parse("select * from t")
   160  	require.NoError(t, err)
   161  	dst.(*Select).SetLimit(limit)
   162  	buf := NewTrackedBuffer(nil)
   163  	dst.Format(buf)
   164  	require.Equal(t, "select * from t limit 4", buf.String())
   165  	dst, err = Parse("select * from t union select * from s")
   166  	require.NoError(t, err)
   167  	dst.(*Union).SetLimit(limit)
   168  	buf = NewTrackedBuffer(nil)
   169  	dst.Format(buf)
   170  	require.Equal(t, "select * from t union select * from s limit 4", buf.String())
   171  }
   172  
   173  func TestDDL(t *testing.T) {
   174  	testcases := []struct {
   175  		query    string
   176  		output   DDLStatement
   177  		affected []string
   178  	}{{
   179  		query: "create table a",
   180  		output: &CreateTable{
   181  			Table: TableName{Name: NewIdentifierCS("a")},
   182  		},
   183  		affected: []string{"a"},
   184  	}, {
   185  		query: "rename table a to b",
   186  		output: &RenameTable{
   187  			TablePairs: []*RenameTablePair{
   188  				{
   189  					FromTable: TableName{Name: NewIdentifierCS("a")},
   190  					ToTable:   TableName{Name: NewIdentifierCS("b")},
   191  				},
   192  			},
   193  		},
   194  		affected: []string{"a", "b"},
   195  	}, {
   196  		query: "rename table a to b, c to d",
   197  		output: &RenameTable{
   198  			TablePairs: []*RenameTablePair{
   199  				{
   200  					FromTable: TableName{Name: NewIdentifierCS("a")},
   201  					ToTable:   TableName{Name: NewIdentifierCS("b")},
   202  				}, {
   203  					FromTable: TableName{Name: NewIdentifierCS("c")},
   204  					ToTable:   TableName{Name: NewIdentifierCS("d")},
   205  				},
   206  			},
   207  		},
   208  		affected: []string{"a", "b", "c", "d"},
   209  	}, {
   210  		query: "drop table a",
   211  		output: &DropTable{
   212  			FromTables: TableNames{
   213  				TableName{Name: NewIdentifierCS("a")},
   214  			},
   215  		},
   216  		affected: []string{"a"},
   217  	}, {
   218  		query: "drop table a, b",
   219  		output: &DropTable{
   220  			FromTables: TableNames{
   221  				TableName{Name: NewIdentifierCS("a")},
   222  				TableName{Name: NewIdentifierCS("b")},
   223  			},
   224  		},
   225  		affected: []string{"a", "b"},
   226  	}}
   227  	for _, tcase := range testcases {
   228  		got, err := Parse(tcase.query)
   229  		if err != nil {
   230  			t.Fatal(err)
   231  		}
   232  		if !reflect.DeepEqual(got, tcase.output) {
   233  			t.Errorf("%s: %v, want %v", tcase.query, got, tcase.output)
   234  		}
   235  		want := make(TableNames, 0, len(tcase.affected))
   236  		for _, t := range tcase.affected {
   237  			want = append(want, TableName{Name: NewIdentifierCS(t)})
   238  		}
   239  		if affected := got.(DDLStatement).AffectedTables(); !reflect.DeepEqual(affected, want) {
   240  			t.Errorf("Affected(%s): %v, want %v", tcase.query, affected, want)
   241  		}
   242  	}
   243  }
   244  
   245  func TestSetAutocommitON(t *testing.T) {
   246  	stmt, err := Parse("SET autocommit=ON")
   247  	require.NoError(t, err)
   248  	s, ok := stmt.(*Set)
   249  	if !ok {
   250  		t.Errorf("SET statement is not Set: %T", s)
   251  	}
   252  
   253  	if len(s.Exprs) < 1 {
   254  		t.Errorf("SET statement has no expressions")
   255  	}
   256  
   257  	e := s.Exprs[0]
   258  	switch v := e.Expr.(type) {
   259  	case *Literal:
   260  		if v.Type != StrVal {
   261  			t.Errorf("SET statement value is not StrVal: %T", v)
   262  		}
   263  
   264  		if "on" != v.Val {
   265  			t.Errorf("SET statement value want: on, got: %s", v.Val)
   266  		}
   267  	default:
   268  		t.Errorf("SET statement expression is not Literal: %T", e.Expr)
   269  	}
   270  
   271  	stmt, err = Parse("SET @@session.autocommit=ON")
   272  	require.NoError(t, err)
   273  	s, ok = stmt.(*Set)
   274  	if !ok {
   275  		t.Errorf("SET statement is not Set: %T", s)
   276  	}
   277  
   278  	if len(s.Exprs) < 1 {
   279  		t.Errorf("SET statement has no expressions")
   280  	}
   281  
   282  	e = s.Exprs[0]
   283  	switch v := e.Expr.(type) {
   284  	case *Literal:
   285  		if v.Type != StrVal {
   286  			t.Errorf("SET statement value is not StrVal: %T", v)
   287  		}
   288  
   289  		if "on" != v.Val {
   290  			t.Errorf("SET statement value want: on, got: %s", v.Val)
   291  		}
   292  	default:
   293  		t.Errorf("SET statement expression is not Literal: %T", e.Expr)
   294  	}
   295  }
   296  
   297  func TestSetAutocommitOFF(t *testing.T) {
   298  	stmt, err := Parse("SET autocommit=OFF")
   299  	require.NoError(t, err)
   300  	s, ok := stmt.(*Set)
   301  	if !ok {
   302  		t.Errorf("SET statement is not Set: %T", s)
   303  	}
   304  
   305  	if len(s.Exprs) < 1 {
   306  		t.Errorf("SET statement has no expressions")
   307  	}
   308  
   309  	e := s.Exprs[0]
   310  	switch v := e.Expr.(type) {
   311  	case *Literal:
   312  		if v.Type != StrVal {
   313  			t.Errorf("SET statement value is not StrVal: %T", v)
   314  		}
   315  
   316  		if "off" != v.Val {
   317  			t.Errorf("SET statement value want: on, got: %s", v.Val)
   318  		}
   319  	default:
   320  		t.Errorf("SET statement expression is not Literal: %T", e.Expr)
   321  	}
   322  
   323  	stmt, err = Parse("SET @@session.autocommit=OFF")
   324  	require.NoError(t, err)
   325  	s, ok = stmt.(*Set)
   326  	if !ok {
   327  		t.Errorf("SET statement is not Set: %T", s)
   328  	}
   329  
   330  	if len(s.Exprs) < 1 {
   331  		t.Errorf("SET statement has no expressions")
   332  	}
   333  
   334  	e = s.Exprs[0]
   335  	switch v := e.Expr.(type) {
   336  	case *Literal:
   337  		if v.Type != StrVal {
   338  			t.Errorf("SET statement value is not StrVal: %T", v)
   339  		}
   340  
   341  		if "off" != v.Val {
   342  			t.Errorf("SET statement value want: on, got: %s", v.Val)
   343  		}
   344  	default:
   345  		t.Errorf("SET statement expression is not Literal: %T", e.Expr)
   346  	}
   347  
   348  }
   349  
   350  func TestWhere(t *testing.T) {
   351  	var w *Where
   352  	buf := NewTrackedBuffer(nil)
   353  	w.Format(buf)
   354  	if buf.String() != "" {
   355  		t.Errorf("w.Format(nil): %q, want \"\"", buf.String())
   356  	}
   357  	w = NewWhere(WhereClause, nil)
   358  	buf = NewTrackedBuffer(nil)
   359  	w.Format(buf)
   360  	if buf.String() != "" {
   361  		t.Errorf("w.Format(&Where{nil}: %q, want \"\"", buf.String())
   362  	}
   363  }
   364  
   365  func TestIsAggregate(t *testing.T) {
   366  	f := FuncExpr{Name: NewIdentifierCI("avg")}
   367  	if !f.IsAggregate() {
   368  		t.Error("IsAggregate: false, want true")
   369  	}
   370  
   371  	f = FuncExpr{Name: NewIdentifierCI("Avg")}
   372  	if !f.IsAggregate() {
   373  		t.Error("IsAggregate: false, want true")
   374  	}
   375  
   376  	f = FuncExpr{Name: NewIdentifierCI("foo")}
   377  	if f.IsAggregate() {
   378  		t.Error("IsAggregate: true, want false")
   379  	}
   380  }
   381  
   382  func TestIsImpossible(t *testing.T) {
   383  	f := ComparisonExpr{
   384  		Operator: NotEqualOp,
   385  		Left:     NewIntLiteral("1"),
   386  		Right:    NewIntLiteral("1"),
   387  	}
   388  	if !f.IsImpossible() {
   389  		t.Error("IsImpossible: false, want true")
   390  	}
   391  
   392  	f = ComparisonExpr{
   393  		Operator: EqualOp,
   394  		Left:     NewIntLiteral("1"),
   395  		Right:    NewIntLiteral("1"),
   396  	}
   397  	if f.IsImpossible() {
   398  		t.Error("IsImpossible: true, want false")
   399  	}
   400  
   401  	f = ComparisonExpr{
   402  		Operator: NotEqualOp,
   403  		Left:     NewIntLiteral("1"),
   404  		Right:    NewIntLiteral("2"),
   405  	}
   406  	if f.IsImpossible() {
   407  		t.Error("IsImpossible: true, want false")
   408  	}
   409  }
   410  
   411  func TestReplaceExpr(t *testing.T) {
   412  	tcases := []struct {
   413  		in, out string
   414  	}{{
   415  		in:  "select * from t where (select a from b)",
   416  		out: ":a",
   417  	}, {
   418  		in:  "select * from t where (select a from b) and b",
   419  		out: ":a and b",
   420  	}, {
   421  		in:  "select * from t where a and (select a from b)",
   422  		out: "a and :a",
   423  	}, {
   424  		in:  "select * from t where (select a from b) or b",
   425  		out: ":a or b",
   426  	}, {
   427  		in:  "select * from t where a or (select a from b)",
   428  		out: "a or :a",
   429  	}, {
   430  		in:  "select * from t where not (select a from b)",
   431  		out: "not :a",
   432  	}, {
   433  		in:  "select * from t where ((select a from b))",
   434  		out: ":a",
   435  	}, {
   436  		in:  "select * from t where (select a from b) = 1",
   437  		out: ":a = 1",
   438  	}, {
   439  		in:  "select * from t where a = (select a from b)",
   440  		out: "a = :a",
   441  	}, {
   442  		in:  "select * from t where a like b escape (select a from b)",
   443  		out: "a like b escape :a",
   444  	}, {
   445  		in:  "select * from t where (select a from b) between a and b",
   446  		out: ":a between a and b",
   447  	}, {
   448  		in:  "select * from t where a between (select a from b) and b",
   449  		out: "a between :a and b",
   450  	}, {
   451  		in:  "select * from t where a between b and (select a from b)",
   452  		out: "a between b and :a",
   453  	}, {
   454  		in:  "select * from t where (select a from b) is null",
   455  		out: ":a is null",
   456  	}, {
   457  		// exists should not replace.
   458  		in:  "select * from t where exists (select a from b)",
   459  		out: "exists (select a from b)",
   460  	}, {
   461  		in:  "select * from t where a in ((select a from b), 1)",
   462  		out: "a in (:a, 1)",
   463  	}, {
   464  		in:  "select * from t where a in (0, (select a from b), 1)",
   465  		out: "a in (0, :a, 1)",
   466  	}, {
   467  		in:  "select * from t where (select a from b) + 1",
   468  		out: ":a + 1",
   469  	}, {
   470  		in:  "select * from t where 1+(select a from b)",
   471  		out: "1 + :a",
   472  	}, {
   473  		in:  "select * from t where -(select a from b)",
   474  		out: "-:a",
   475  	}, {
   476  		in:  "select * from t where interval (select a from b) aa",
   477  		out: "interval :a aa",
   478  	}, {
   479  		in:  "select * from t where (select a from b) collate utf8",
   480  		out: ":a collate utf8",
   481  	}, {
   482  		in:  "select * from t where func((select a from b), 1)",
   483  		out: "func(:a, 1)",
   484  	}, {
   485  		in:  "select * from t where func(1, (select a from b), 1)",
   486  		out: "func(1, :a, 1)",
   487  	}, {
   488  		in:  "select * from t where group_concat((select a from b), 1 order by a)",
   489  		out: "group_concat(:a, 1 order by a asc)",
   490  	}, {
   491  		in:  "select * from t where group_concat(1 order by (select a from b), a)",
   492  		out: "group_concat(1 order by :a asc, a asc)",
   493  	}, {
   494  		in:  "select * from t where group_concat(1 order by a, (select a from b))",
   495  		out: "group_concat(1 order by a asc, :a asc)",
   496  	}, {
   497  		in:  "select * from t where substr(a, (select a from b), b)",
   498  		out: "substr(a, :a, b)",
   499  	}, {
   500  		in:  "select * from t where substr(a, b, (select a from b))",
   501  		out: "substr(a, b, :a)",
   502  	}, {
   503  		in:  "select * from t where convert((select a from b), json)",
   504  		out: "convert(:a, json)",
   505  	}, {
   506  		in:  "select * from t where convert((select a from b) using utf8)",
   507  		out: "convert(:a using utf8)",
   508  	}, {
   509  		in:  "select * from t where case (select a from b) when a then b when b then c else d end",
   510  		out: "case :a when a then b when b then c else d end",
   511  	}, {
   512  		in:  "select * from t where case a when (select a from b) then b when b then c else d end",
   513  		out: "case a when :a then b when b then c else d end",
   514  	}, {
   515  		in:  "select * from t where case a when b then (select a from b) when b then c else d end",
   516  		out: "case a when b then :a when b then c else d end",
   517  	}, {
   518  		in:  "select * from t where case a when b then c when (select a from b) then c else d end",
   519  		out: "case a when b then c when :a then c else d end",
   520  	}, {
   521  		in:  "select * from t where case a when b then c when d then c else (select a from b) end",
   522  		out: "case a when b then c when d then c else :a end",
   523  	}}
   524  	to := NewArgument("a")
   525  	for _, tcase := range tcases {
   526  		tree, err := Parse(tcase.in)
   527  		if err != nil {
   528  			t.Fatal(err)
   529  		}
   530  		var from *Subquery
   531  		_ = Walk(func(node SQLNode) (kontinue bool, err error) {
   532  			if sq, ok := node.(*Subquery); ok {
   533  				from = sq
   534  				return false, nil
   535  			}
   536  			return true, nil
   537  		}, tree)
   538  		if from == nil {
   539  			t.Fatalf("from is nil for %s", tcase.in)
   540  		}
   541  		expr := ReplaceExpr(tree.(*Select).Where.Expr, from, to)
   542  		got := String(expr)
   543  		if tcase.out != got {
   544  			t.Errorf("ReplaceExpr(%s): %s, want %s", tcase.in, got, tcase.out)
   545  		}
   546  	}
   547  }
   548  
   549  func TestColNameEqual(t *testing.T) {
   550  	var c1, c2 *ColName
   551  	if c1.Equal(c2) {
   552  		t.Error("nil columns equal, want unequal")
   553  	}
   554  	c1 = &ColName{
   555  		Name: NewIdentifierCI("aa"),
   556  	}
   557  	c2 = &ColName{
   558  		Name: NewIdentifierCI("bb"),
   559  	}
   560  	if c1.Equal(c2) {
   561  		t.Error("columns equal, want unequal")
   562  	}
   563  	c2.Name = NewIdentifierCI("aa")
   564  	if !c1.Equal(c2) {
   565  		t.Error("columns unequal, want equal")
   566  	}
   567  }
   568  
   569  func TestIdentifierCI(t *testing.T) {
   570  	str := NewIdentifierCI("Ab")
   571  	if str.String() != "Ab" {
   572  		t.Errorf("String=%s, want Ab", str.String())
   573  	}
   574  	if str.String() != "Ab" {
   575  		t.Errorf("Val=%s, want Ab", str.String())
   576  	}
   577  	if str.Lowered() != "ab" {
   578  		t.Errorf("Val=%s, want ab", str.Lowered())
   579  	}
   580  	if !str.Equal(NewIdentifierCI("aB")) {
   581  		t.Error("str.Equal(NewIdentifierCI(aB))=false, want true")
   582  	}
   583  	if !str.EqualString("ab") {
   584  		t.Error("str.EqualString(ab)=false, want true")
   585  	}
   586  	str = NewIdentifierCI("")
   587  	if str.Lowered() != "" {
   588  		t.Errorf("Val=%s, want \"\"", str.Lowered())
   589  	}
   590  }
   591  
   592  func TestIdentifierCIMarshal(t *testing.T) {
   593  	str := NewIdentifierCI("Ab")
   594  	b, err := json.Marshal(str)
   595  	if err != nil {
   596  		t.Fatal(err)
   597  	}
   598  	got := string(b)
   599  	want := `"Ab"`
   600  	if got != want {
   601  		t.Errorf("json.Marshal()= %s, want %s", got, want)
   602  	}
   603  	var out IdentifierCI
   604  	if err := json.Unmarshal(b, &out); err != nil {
   605  		t.Errorf("Unmarshal err: %v, want nil", err)
   606  	}
   607  	if !reflect.DeepEqual(out, str) {
   608  		t.Errorf("Unmarshal: %v, want %v", out, str)
   609  	}
   610  }
   611  
   612  func TestIdentifierCISize(t *testing.T) {
   613  	size := unsafe.Sizeof(NewIdentifierCI(""))
   614  	want := 2 * unsafe.Sizeof("")
   615  	assert.Equal(t, want, size, "size of IdentifierCI")
   616  }
   617  
   618  func TestIdentifierCSMarshal(t *testing.T) {
   619  	str := NewIdentifierCS("Ab")
   620  	b, err := json.Marshal(str)
   621  	if err != nil {
   622  		t.Fatal(err)
   623  	}
   624  	got := string(b)
   625  	want := `"Ab"`
   626  	if got != want {
   627  		t.Errorf("json.Marshal()= %s, want %s", got, want)
   628  	}
   629  	var out IdentifierCS
   630  	if err := json.Unmarshal(b, &out); err != nil {
   631  		t.Errorf("Unmarshal err: %v, want nil", err)
   632  	}
   633  	if !reflect.DeepEqual(out, str) {
   634  		t.Errorf("Unmarshal: %v, want %v", out, str)
   635  	}
   636  }
   637  
   638  func TestHexDecode(t *testing.T) {
   639  	testcase := []struct {
   640  		in, out string
   641  	}{{
   642  		in:  "313233",
   643  		out: "123",
   644  	}, {
   645  		in:  "ag",
   646  		out: "encoding/hex: invalid byte: U+0067 'g'",
   647  	}, {
   648  		in:  "777",
   649  		out: "encoding/hex: odd length hex string",
   650  	}}
   651  	for _, tc := range testcase {
   652  		out, err := NewHexLiteral(tc.in).HexDecode()
   653  		if err != nil {
   654  			if err.Error() != tc.out {
   655  				t.Errorf("Decode(%q): %v, want %s", tc.in, err, tc.out)
   656  			}
   657  			continue
   658  		}
   659  		if !bytes.Equal(out, []byte(tc.out)) {
   660  			t.Errorf("Decode(%q): %s, want %s", tc.in, out, tc.out)
   661  		}
   662  	}
   663  }
   664  
   665  func TestCompliantName(t *testing.T) {
   666  	testcases := []struct {
   667  		in, out string
   668  	}{{
   669  		in:  "aa",
   670  		out: "aa",
   671  	}, {
   672  		in:  "1a",
   673  		out: "_a",
   674  	}, {
   675  		in:  "a1",
   676  		out: "a1",
   677  	}, {
   678  		in:  "a.b",
   679  		out: "a_b",
   680  	}, {
   681  		in:  ".ab",
   682  		out: "_ab",
   683  	}}
   684  	for _, tc := range testcases {
   685  		out := NewIdentifierCI(tc.in).CompliantName()
   686  		if out != tc.out {
   687  			t.Errorf("IdentifierCI(%s).CompliantNamt: %s, want %s", tc.in, out, tc.out)
   688  		}
   689  		out = NewIdentifierCS(tc.in).CompliantName()
   690  		if out != tc.out {
   691  			t.Errorf("IdentifierCS(%s).CompliantNamt: %s, want %s", tc.in, out, tc.out)
   692  		}
   693  	}
   694  }
   695  
   696  func TestColumns_FindColumn(t *testing.T) {
   697  	cols := Columns{NewIdentifierCI("a"), NewIdentifierCI("c"), NewIdentifierCI("b"), NewIdentifierCI("0")}
   698  
   699  	testcases := []struct {
   700  		in  string
   701  		out int
   702  	}{{
   703  		in:  "a",
   704  		out: 0,
   705  	}, {
   706  		in:  "b",
   707  		out: 2,
   708  	},
   709  		{
   710  			in:  "0",
   711  			out: 3,
   712  		},
   713  		{
   714  			in:  "f",
   715  			out: -1,
   716  		}}
   717  
   718  	for _, tc := range testcases {
   719  		val := cols.FindColumn(NewIdentifierCI(tc.in))
   720  		if val != tc.out {
   721  			t.Errorf("FindColumn(%s): %d, want %d", tc.in, val, tc.out)
   722  		}
   723  	}
   724  }
   725  
   726  func TestSplitStatementToPieces(t *testing.T) {
   727  	testcases := []struct {
   728  		input  string
   729  		output string
   730  	}{{
   731  		input:  "select * from table1; \t; \n; \n\t\t ;select * from table1;",
   732  		output: "select * from table1;select * from table1",
   733  	}, {
   734  		input: "select * from table",
   735  	}, {
   736  		input:  "select * from table;",
   737  		output: "select * from table",
   738  	}, {
   739  		input:  "select * from table;   ",
   740  		output: "select * from table",
   741  	}, {
   742  		input:  "select * from table1; select * from table2;",
   743  		output: "select * from table1; select * from table2",
   744  	}, {
   745  		input:  "select * from /* comment ; */ table;",
   746  		output: "select * from /* comment ; */ table",
   747  	}, {
   748  		input:  "select * from table where semi = ';';",
   749  		output: "select * from table where semi = ';'",
   750  	}, {
   751  		input:  "select * from table1;--comment;\nselect * from table2;",
   752  		output: "select * from table1;--comment;\nselect * from table2",
   753  	}, {
   754  		input: "CREATE TABLE `total_data` (`id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'id', " +
   755  			"`region` varchar(32) NOT NULL COMMENT 'region name, like zh; th; kepler'," +
   756  			"`data_size` bigint NOT NULL DEFAULT '0' COMMENT 'data size;'," +
   757  			"`createtime` datetime NOT NULL DEFAULT NOW() COMMENT 'create time;'," +
   758  			"`comment` varchar(100) NOT NULL DEFAULT '' COMMENT 'comment'," +
   759  			"PRIMARY KEY (`id`))",
   760  	}}
   761  
   762  	for _, tcase := range testcases {
   763  		t.Run(tcase.input, func(t *testing.T) {
   764  			if tcase.output == "" {
   765  				tcase.output = tcase.input
   766  			}
   767  
   768  			stmtPieces, err := SplitStatementToPieces(tcase.input)
   769  			require.NoError(t, err)
   770  
   771  			out := strings.Join(stmtPieces, ";")
   772  			require.Equal(t, tcase.output, out)
   773  		})
   774  	}
   775  }
   776  
   777  func TestTypeConversion(t *testing.T) {
   778  	ct1 := &ColumnType{Type: "BIGINT"}
   779  	ct2 := &ColumnType{Type: "bigint"}
   780  	assert.Equal(t, ct1.SQLType(), ct2.SQLType())
   781  }
   782  
   783  func TestDefaultStatus(t *testing.T) {
   784  	assert.Equal(t,
   785  		String(&Default{ColName: "status"}),
   786  		"default(`status`)")
   787  }
   788  
   789  func TestShowTableStatus(t *testing.T) {
   790  	query := "Show Table Status FROM customer"
   791  	tree, err := Parse(query)
   792  	require.NoError(t, err)
   793  	require.NotNil(t, tree)
   794  }
   795  
   796  func BenchmarkStringTraces(b *testing.B) {
   797  	for _, trace := range []string{"django_queries.txt", "lobsters.sql.gz"} {
   798  		b.Run(trace, func(b *testing.B) {
   799  			queries := loadQueries(b, trace)
   800  			if len(queries) > 10000 {
   801  				queries = queries[:10000]
   802  			}
   803  
   804  			parsed := make([]Statement, 0, len(queries))
   805  			for _, q := range queries {
   806  				pp, err := Parse(q)
   807  				if err != nil {
   808  					b.Fatal(err)
   809  				}
   810  				parsed = append(parsed, pp)
   811  			}
   812  
   813  			b.ResetTimer()
   814  			b.ReportAllocs()
   815  
   816  			for i := 0; i < b.N; i++ {
   817  				for _, stmt := range parsed {
   818  					_ = String(stmt)
   819  				}
   820  			}
   821  		})
   822  	}
   823  }