vitess.io/vitess@v0.16.2/go/vt/sqlparser/analyzer_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  )
    24  
    25  func TestPreview(t *testing.T) {
    26  	testcases := []struct {
    27  		sql  string
    28  		want StatementType
    29  	}{
    30  		{"select ...", StmtSelect},
    31  		{"    select ...", StmtSelect},
    32  		{"(select ...", StmtSelect},
    33  		{"( select ...", StmtSelect},
    34  		{"insert ...", StmtInsert},
    35  		{"replace ....", StmtReplace},
    36  		{"   update ...", StmtUpdate},
    37  		{"Update", StmtUpdate},
    38  		{"UPDATE ...", StmtUpdate},
    39  		{"\n\t    delete ...", StmtDelete},
    40  		{"", StmtUnknown},
    41  		{" ", StmtUnknown},
    42  		{"begin", StmtBegin},
    43  		{" begin", StmtBegin},
    44  		{" begin ", StmtBegin},
    45  		{"\n\t begin ", StmtBegin},
    46  		{"... begin ", StmtUnknown},
    47  		{"begin ...", StmtUnknown},
    48  		{"begin /* ... */", StmtBegin},
    49  		{"begin /* ... *//*test*/", StmtBegin},
    50  		{"begin;", StmtBegin},
    51  		{"begin ;", StmtBegin},
    52  		{"begin; /*...*/", StmtBegin},
    53  		{"start transaction", StmtBegin},
    54  		{"commit", StmtCommit},
    55  		{"commit /*...*/", StmtCommit},
    56  		{"rollback", StmtRollback},
    57  		{"rollback /*...*/", StmtRollback},
    58  		{"create", StmtDDL},
    59  		{"alter", StmtDDL},
    60  		{"rename", StmtDDL},
    61  		{"drop", StmtDDL},
    62  		{"set", StmtSet},
    63  		{"show", StmtShow},
    64  		{"use", StmtUse},
    65  		{"analyze", StmtOther},
    66  		{"describe", StmtExplain},
    67  		{"desc", StmtExplain},
    68  		{"explain", StmtExplain},
    69  		{"repair", StmtOther},
    70  		{"optimize", StmtOther},
    71  		{"grant", StmtPriv},
    72  		{"revoke", StmtPriv},
    73  		{"truncate", StmtDDL},
    74  		{"flush", StmtFlush},
    75  		{"unknown", StmtUnknown},
    76  
    77  		{"/* leading comment */ select ...", StmtSelect},
    78  		{"/* leading comment */ (select ...", StmtSelect},
    79  		{"/* leading comment */ /* leading comment 2 */ select ...", StmtSelect},
    80  		{"/*! MySQL-specific comment */", StmtComment},
    81  		{"/*!50708 MySQL-version comment */", StmtComment},
    82  		{"-- leading single line comment \n select ...", StmtSelect},
    83  		{"-- leading single line comment \n -- leading single line comment 2\n select ...", StmtSelect},
    84  
    85  		{"/* leading comment no end select ...", StmtUnknown},
    86  		{"-- leading single line comment no end select ...", StmtUnknown},
    87  		{"/*!40000 ALTER TABLE `t1` DISABLE KEYS */", StmtComment},
    88  	}
    89  	for _, tcase := range testcases {
    90  		if got := Preview(tcase.sql); got != tcase.want {
    91  			t.Errorf("Preview(%s): %v, want %v", tcase.sql, got, tcase.want)
    92  		}
    93  	}
    94  }
    95  
    96  func TestIsDML(t *testing.T) {
    97  	testcases := []struct {
    98  		sql  string
    99  		want bool
   100  	}{
   101  		{"   update ...", true},
   102  		{"Update", true},
   103  		{"UPDATE ...", true},
   104  		{"\n\t    delete ...", true},
   105  		{"insert ...", true},
   106  		{"replace ...", true},
   107  		{"select ...", false},
   108  		{"    select ...", false},
   109  		{"", false},
   110  		{" ", false},
   111  	}
   112  	for _, tcase := range testcases {
   113  		if got := IsDML(tcase.sql); got != tcase.want {
   114  			t.Errorf("IsDML(%s): %v, want %v", tcase.sql, got, tcase.want)
   115  		}
   116  	}
   117  }
   118  
   119  func TestSplitAndExpression(t *testing.T) {
   120  	testcases := []struct {
   121  		sql string
   122  		out []string
   123  	}{{
   124  		sql: "select * from t",
   125  		out: nil,
   126  	}, {
   127  		sql: "select * from t where a = 1",
   128  		out: []string{"a = 1"},
   129  	}, {
   130  		sql: "select * from t where a = 1 and b = 1",
   131  		out: []string{"a = 1", "b = 1"},
   132  	}, {
   133  		sql: "select * from t where a = 1 and (b = 1 and c = 1)",
   134  		out: []string{"a = 1", "b = 1", "c = 1"},
   135  	}, {
   136  		sql: "select * from t where a = 1 and (b = 1 or c = 1)",
   137  		out: []string{"a = 1", "b = 1 or c = 1"},
   138  	}, {
   139  		sql: "select * from t where a = 1 and b = 1 or c = 1",
   140  		out: []string{"a = 1 and b = 1 or c = 1"},
   141  	}, {
   142  		sql: "select * from t where a = 1 and b = 1 + (c = 1)",
   143  		out: []string{"a = 1", "b = 1 + (c = 1)"},
   144  	}, {
   145  		sql: "select * from t where (a = 1 and ((b = 1 and c = 1)))",
   146  		out: []string{"a = 1", "b = 1", "c = 1"},
   147  	}}
   148  	for _, tcase := range testcases {
   149  		stmt, err := Parse(tcase.sql)
   150  		assert.NoError(t, err)
   151  		var expr Expr
   152  		if where := stmt.(*Select).Where; where != nil {
   153  			expr = where.Expr
   154  		}
   155  		splits := SplitAndExpression(nil, expr)
   156  		var got []string
   157  		for _, split := range splits {
   158  			got = append(got, String(split))
   159  		}
   160  		assert.Equal(t, tcase.out, got)
   161  	}
   162  }
   163  
   164  func TestAndExpressions(t *testing.T) {
   165  	greaterThanExpr := &ComparisonExpr{
   166  		Operator: GreaterThanOp,
   167  		Left: &ColName{
   168  			Name: NewIdentifierCI("val"),
   169  			Qualifier: TableName{
   170  				Name: NewIdentifierCS("a"),
   171  			},
   172  		},
   173  		Right: &ColName{
   174  			Name: NewIdentifierCI("val"),
   175  			Qualifier: TableName{
   176  				Name: NewIdentifierCS("b"),
   177  			},
   178  		},
   179  	}
   180  	equalExpr := &ComparisonExpr{
   181  		Operator: EqualOp,
   182  		Left: &ColName{
   183  			Name: NewIdentifierCI("id"),
   184  			Qualifier: TableName{
   185  				Name: NewIdentifierCS("a"),
   186  			},
   187  		},
   188  		Right: &ColName{
   189  			Name: NewIdentifierCI("id"),
   190  			Qualifier: TableName{
   191  				Name: NewIdentifierCS("b"),
   192  			},
   193  		},
   194  	}
   195  	testcases := []struct {
   196  		name           string
   197  		expressions    Exprs
   198  		expectedOutput Expr
   199  	}{
   200  		{
   201  			name:           "empty input",
   202  			expressions:    nil,
   203  			expectedOutput: nil,
   204  		}, {
   205  			name: "two equal inputs",
   206  			expressions: Exprs{
   207  				greaterThanExpr,
   208  				equalExpr,
   209  				equalExpr,
   210  			},
   211  			expectedOutput: &AndExpr{
   212  				Left:  greaterThanExpr,
   213  				Right: equalExpr,
   214  			},
   215  		},
   216  		{
   217  			name: "two equal inputs",
   218  			expressions: Exprs{
   219  				equalExpr,
   220  				equalExpr,
   221  			},
   222  			expectedOutput: equalExpr,
   223  		},
   224  	}
   225  
   226  	for _, testcase := range testcases {
   227  		t.Run(testcase.name, func(t *testing.T) {
   228  			output := AndExpressions(testcase.expressions...)
   229  			assert.Equal(t, String(testcase.expectedOutput), String(output))
   230  		})
   231  	}
   232  }
   233  
   234  func TestTableFromStatement(t *testing.T) {
   235  	testcases := []struct {
   236  		in, out string
   237  	}{{
   238  		in:  "select * from t",
   239  		out: "t",
   240  	}, {
   241  		in:  "select * from t.t",
   242  		out: "t.t",
   243  	}, {
   244  		in:  "select * from t1, t2",
   245  		out: "table expression is complex",
   246  	}, {
   247  		in:  "select * from (t)",
   248  		out: "table expression is complex",
   249  	}, {
   250  		in:  "select * from t1 join t2",
   251  		out: "table expression is complex",
   252  	}, {
   253  		in:  "select * from (select * from t) as tt",
   254  		out: "table expression is complex",
   255  	}, {
   256  		in:  "update t set a=1",
   257  		out: "unrecognized statement: update t set a=1",
   258  	}, {
   259  		in:  "bad query",
   260  		out: "syntax error at position 4 near 'bad'",
   261  	}}
   262  
   263  	for _, tc := range testcases {
   264  		name, err := TableFromStatement(tc.in)
   265  		var got string
   266  		if err != nil {
   267  			got = err.Error()
   268  		} else {
   269  			got = String(name)
   270  		}
   271  		if got != tc.out {
   272  			t.Errorf("TableFromStatement('%s'): %s, want %s", tc.in, got, tc.out)
   273  		}
   274  	}
   275  }
   276  
   277  func TestGetTableName(t *testing.T) {
   278  	testcases := []struct {
   279  		in, out string
   280  	}{{
   281  		in:  "select * from t",
   282  		out: "t",
   283  	}, {
   284  		in:  "select * from t.t",
   285  		out: "",
   286  	}, {
   287  		in:  "select * from (select * from t) as tt",
   288  		out: "",
   289  	}}
   290  
   291  	for _, tc := range testcases {
   292  		tree, err := Parse(tc.in)
   293  		if err != nil {
   294  			t.Error(err)
   295  			continue
   296  		}
   297  		out := GetTableName(tree.(*Select).From[0].(*AliasedTableExpr).Expr)
   298  		if out.String() != tc.out {
   299  			t.Errorf("GetTableName('%s'): %s, want %s", tc.in, out, tc.out)
   300  		}
   301  	}
   302  }
   303  
   304  func TestIsColName(t *testing.T) {
   305  	testcases := []struct {
   306  		in  Expr
   307  		out bool
   308  	}{{
   309  		in:  &ColName{},
   310  		out: true,
   311  	}, {
   312  		in: NewHexLiteral(""),
   313  	}}
   314  	for _, tc := range testcases {
   315  		out := IsColName(tc.in)
   316  		if out != tc.out {
   317  			t.Errorf("IsColName(%T): %v, want %v", tc.in, out, tc.out)
   318  		}
   319  	}
   320  }
   321  
   322  func TestIsNull(t *testing.T) {
   323  	testcases := []struct {
   324  		in  Expr
   325  		out bool
   326  	}{{
   327  		in:  &NullVal{},
   328  		out: true,
   329  	}, {
   330  		in: NewStrLiteral(""),
   331  	}}
   332  	for _, tc := range testcases {
   333  		out := IsNull(tc.in)
   334  		if out != tc.out {
   335  			t.Errorf("IsNull(%T): %v, want %v", tc.in, out, tc.out)
   336  		}
   337  	}
   338  }